diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -33,11 +33,22 @@ /// A map for looking up bufferized function types. DenseMap bufferizedFunctionTypes; + + /// A mapping of return values to equivalent BlockArguments. + DenseMap equivalentReturnValToBBArg; }; } // namespace static bool isaTensor(Type t) { return t.isa(); } +/// If `value` is a memref::CastOp, return its source. Otherwise, return +/// `value` directly. +static Value getNonCastedValue(Value value) { + while (auto castOp = value.getDefiningOp()) + value = castOp.source(); + return value; +} + /// Remove the attribute that triggers inplace bufferization on a FuncOp /// argument `bbArg`. static void removeBufferizationFuncArguments(BlockArgument bbArg) { @@ -112,62 +123,40 @@ return it2.first->second; } -/// Return the op with Allocate MemoryEffect if `v` is equivalent to such an -/// an op. Return null otherwise. -static Operation *getEquivalentAlloc(Value value, - const BufferizationAliasInfo &aliasInfo) { - Operation *res = nullptr; - aliasInfo.applyOnEquivalenceClass(value, [&](Value v) { - if (!res) - if (auto interface = - dyn_cast_or_null(v.getDefiningOp())) - if (auto effect = - interface.getEffectOnValue(v)) - res = v.getDefiningOp(); - }); - return res; -} +/// Store function BlockArguments that are equivalent to a returned value in +/// the given ModuleBufferizationState. +static void populateEquivalentFuncOpBBArgs(FuncOp funcOp, + ModuleBufferizationState &state) { + // Support only single return-terminated block in the function. + ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + assert(returnOp && "expected func with single return op"); -/// Return the first argument of the enclosing FuncOp that is equivalent to `v`. 
-/// Return null if no such bbArg can be found. -static BlockArgument -getEquivalentEnclosingFuncBBArg(Value v, - const BufferizationAliasInfo &aliasInfo) { - if (!v.getType().isa()) - return nullptr; - Operation *op = v.getParentBlock()->getParentOp(); - FuncOp funcOp = dyn_cast(op); - if (!funcOp) - funcOp = op->getParentOfType(); - assert(funcOp && "expected non-null FuncOp"); - for (BlockArgument bbArg : funcOp.getArguments()) { - if (!bbArg.getType().isa()) - continue; - if (aliasInfo.areEquivalentBufferizedValues(v, bbArg)) - return bbArg; - } - return nullptr; + for (Value returnVal : returnOp.operands()) + if (returnVal.getType().isa()) + for (BlockArgument bbArg : funcOp.getArguments()) + if (bbArg.getType().isa()) + if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg)) + state.equivalentReturnValToBBArg[returnVal] = bbArg; } /// Rewrite the `funcOp` arguments analysis return values and terminator into /// buffer form (using the canonical memref layout for now), according to the /// inPlace-bufferizable information of the function arguments. +/// /// This relies on a buffer equivalence analysis of each return operand. When a -/// result buffer is equivalent to: -/// 1. a BlockArgument of `funcOp`, it can be dropped from the return values -/// and becomes inplaceable at all callers. This assumes all CallOp perform -/// the necessary work to clone operands so as to make them inplaceable. -// Reliance on this logic will need to be relaxed in thefuture. -/// 2. an op with an Alloc effect, this currently fails bufferization but is a -/// candidate for hoisting and creating a new inplace operand at all caller -/// sites. -/// 3. if such a hoisting for 2. is not possible (e.g. data-dependent that -/// prevents hoisting), this is currently unsupported and will require a -/// refcounted buffer type. 
-static LogicalResult bufferizeFuncOpBoundary( - FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - DenseMap &bufferizedFunctionTypes) { +/// result buffer is equivalent to a BlockArgument of `funcOp`, it can be +/// dropped from the return values and becomes inplaceable at all callers. This +/// assumes all CallOp perform the necessary work to clone operands so as to +/// make them inplaceable. Reliance on this logic will need to be relaxed in the +/// future. +/// +/// Note: Returning a memref currently fails bufferization. If such memrefs +/// originate from an op with an Alloc effect, they could be hoisted in the +/// future. +static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, + ModuleBufferizationState &state) { LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n"); + BufferizationAliasInfo &aliasInfo = state.aliasInfo; // If nothing to do then we are done. if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) && @@ -196,9 +185,9 @@ if (llvm::any_of(funcOp.getType().getResults(), isaTensor)) return funcOp->emitError() << "cannot bufferize bodiless function that " << "returns a tensor"; - FunctionType bufferizedFuncType = - getOrCreateBufferizedFunctionType(funcOp, funcOp.getType().getInputs(), - TypeRange{}, bufferizedFunctionTypes); + FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( + funcOp, funcOp.getType().getInputs(), TypeRange{}, + state.bufferizedFunctionTypes); funcOp.setType(bufferizedFuncType); LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp); return success(); @@ -211,37 +200,27 @@ // 1. For each FuncOp result, keep track of which inplace argument it reuses. SmallVector returnValues; for (OpOperand &returnOperand : returnOp->getOpOperands()) { + Value returnVal = returnOperand.get(); + + // If not a tensor return type, just forward it. 
- if (!returnOperand.get().getType().isa()) { - returnValues.push_back(returnOperand.get()); + if (!returnVal.getType().isa()) { + returnValues.push_back(returnVal); continue; } // If return operand is equivalent to some bbArg, no need to return it. - Value returnVal = returnOperand.get(); - if (getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo)) + if (state.equivalentReturnValToBBArg.count(returnVal)) continue; - // TODO: Need to hoist above function boundary. - if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo)) { - returnValues.push_back(allocOp->getResult(0)); - continue; - } - - // Other cases legitimately need to return a tensor, this is currently not - // supported. For instance, if hoisting across function boundary has - // failed, it may be due to e.g. data-dependent sizes. In such a case, we - // would need a better type than memref. - int64_t returnIdx = returnOperand.getOperandNumber(); - return returnOp->emitError() - << "buffer result #" << returnIdx << " not produced by an alloc\n"; + // Cast values at the call site if necessary. + returnValues.push_back(getNonCastedValue(state.lookupBuffer(returnVal))); } // 2. Rewrite the terminator without the inPlace bufferizable values. ValueRange retValues{returnValues}; FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( funcOp, funcOp.getType().getInputs(), retValues.getTypes(), - bufferizedFunctionTypes); + state.bufferizedFunctionTypes); OpBuilder b(returnOp); b.create(returnOp.getLoc(), returnValues); returnOp->erase(); @@ -493,6 +472,7 @@ FuncOp funcOp = getCalledFunction(callOp); assert(isa(callOp.getOperation()) && funcOp && "expected Callop to a FuncOp"); + auto &moduleState = static_cast(state); // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -505,11 +485,10 @@ // semantics is ill-defined. // - if the callee has a body, we perform inter-procedural equivalence // analysis. When successful, a result folds onto an operand. 
When - // unsuccessful, additional work is needed to either: + // unsuccessful, additional work is needed (TODO) to either: // * hoist a result into an inplaceable operand or // * devise a better representation to truly return a buffer. SmallVector resultTypes; - SmallVector hoistedArguments; if (funcOp.body().empty()) { if (llvm::any_of(funcOp.getType().getResults(), isaTensor)) return callOp->emitError() @@ -528,8 +507,9 @@ // If return operand is equivalent to some bbArg, no need to return it. Value returnVal = returnOperand.get(); - if (BlockArgument bbArg = - getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) { + if (moduleState.equivalentReturnValToBBArg.count(returnVal)) { + BlockArgument bbArg = + moduleState.equivalentReturnValToBBArg[returnVal]; Value oldRes = callOp->getResult(returnOperand.getOperandNumber()); int64_t idx = bbArg.getArgNumber(); Value buffer = state.lookupBuffer(callOp->getOperand(idx)); @@ -550,35 +530,17 @@ continue; } - // TODO: Need to hoist above function boundary. - if (Operation *allocOp = - getEquivalentAlloc(returnVal, state.aliasInfo)) { - hoistedArguments.push_back(allocOp->getResult(0)); - continue; - } - - // Other cases legitimately need to return a tensor, this is currently - // not supported. For instance, if hoisting across function boundary has - // failed, it may be due to e.g. data-dependent sizes. In such a case, - // we would we need a better type than memref. resultTypes.push_back(returnType); - - int64_t returnIdx = returnOperand.getOperandNumber(); - return returnOp->emitError() << "buffer result #" << returnIdx - << " not produced by an alloc\n"; } } // 2. Compute bufferized FunctionType. SmallVector argumentTypes{callOp->getOperandTypes()}; - ValueRange hoistedArgs{hoistedArguments}; - llvm::append_range(argumentTypes, hoistedArgs.getTypes()); // Get the bufferized FunctionType for funcOp or construct it if not yet // available. - // TODO: Assert that `state` is a ModuleBufferizationState. 
- FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( - funcOp, argumentTypes, resultTypes, - static_cast(state).bufferizedFunctionTypes); + FunctionType bufferizedFuncType = + getOrCreateBufferizedFunctionType(funcOp, argumentTypes, resultTypes, + moduleState.bufferizedFunctionTypes); // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`. SmallVector newOperands; @@ -710,6 +672,8 @@ // Analyze and bufferize funcOp. if (failed(runComprehensiveBufferize(funcOp, options, state))) return failure(); + + populateEquivalentFuncOpBBArgs(funcOp, state); } if (options.testAnalysisOnly) @@ -718,8 +682,7 @@ for (FuncOp funcOp : orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. - if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo, - state.bufferizedFunctionTypes))) + if (failed(bufferizeFuncOpBoundary(funcOp, state))) return failure(); if (!options.allowReturnMemref && diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -110,6 +110,7 @@ // ----- +// expected-error @+1 {{memref return type is unsupported}} func @extract_slice_fun(%A : tensor {linalg.inplaceable = true}) -> tensor<4xf32> { @@ -121,7 +122,6 @@ // argument aliasing). %r0 = tensor.extract_slice %A[0][4][1] : tensor to tensor<4xf32> - // expected-error @+1 {{buffer result #0 not produced by an alloc}} return %r0: tensor<4xf32> }