Please use GitHub pull requests for new patches. Phabricator shutdown timeline
Differential D114508 Diff 389894 mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
Show All 21 Lines | |||||
#define LDBG(X) LLVM_DEBUG(DBGS() << X) | #define LDBG(X) LLVM_DEBUG(DBGS() << X) | ||||
using namespace mlir; | using namespace mlir; | ||||
using namespace linalg; | using namespace linalg; | ||||
using namespace tensor; | using namespace tensor; | ||||
using namespace comprehensive_bufferize; | using namespace comprehensive_bufferize; | ||||
namespace { | namespace { | ||||
/// A specialization of BufferizationState that keeps track of additional | /// Extra bufferization state that is required for bufferization of function | ||||
/// state required for bufferization of function boundaries. | /// boundaries. | ||||
struct ModuleBufferizationState : public BufferizationState { | struct ModuleBufferizationState : public DialectBufferizationState { | ||||
using BufferizationState::BufferizationState; | |||||
/// A map for looking up bufferized function types. | /// A map for looking up bufferized function types. | ||||
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes; | DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes; | ||||
/// A mapping of return values to equivalent BlockArguments. | /// A mapping of return values to equivalent BlockArguments. | ||||
DenseMap<Value, BlockArgument> equivalentReturnValToBBArg; | DenseMap<Value, BlockArgument> equivalentReturnValToBBArg; | ||||
}; | }; | ||||
} // namespace | } // namespace | ||||
static ModuleBufferizationState & | |||||
getModuleBufferizationState(BufferizationState &state) { | |||||
return state.getDialectState<ModuleBufferizationState>( | |||||
StandardOpsDialect::getDialectNamespace()); | |||||
} | |||||
static bool isaTensor(Type t) { return t.isa<TensorType>(); } | static bool isaTensor(Type t) { return t.isa<TensorType>(); } | ||||
/// If `value` is a memref::CastOp, return its source. Otherwise, return | /// If `value` is a memref::CastOp, return its source. Otherwise, return | ||||
/// `value` directly. | /// `value` directly. | ||||
static Value getNonCastedValue(Value value) { | static Value getNonCastedValue(Value value) { | ||||
while (auto castOp = value.getDefiningOp<memref::CastOp>()) | while (auto castOp = value.getDefiningOp<memref::CastOp>()) | ||||
value = castOp.source(); | value = castOp.source(); | ||||
return value; | return value; | ||||
▲ Show 20 Lines • Show All 71 Lines • ▼ Show 20 Lines | auto it2 = bufferizedFunctionTypes.try_emplace( | ||||
resultTypes)); | resultTypes)); | ||||
LDBG("FT: " << funcOp.getType() << " -> " << it2.first->second << "\n"); | LDBG("FT: " << funcOp.getType() << " -> " << it2.first->second << "\n"); | ||||
return it2.first->second; | return it2.first->second; | ||||
} | } | ||||
/// Store function BlockArguments that are equivalent to a returned value in | /// Store function BlockArguments that are equivalent to a returned value in | ||||
/// the given ModuleBufferizationState. | /// the given ModuleBufferizationState. | ||||
static void populateEquivalentFuncOpBBArgs(FuncOp funcOp, | static void populateEquivalentFuncOpBBArgs(FuncOp funcOp, | ||||
ModuleBufferizationState &state) { | BufferizationState &state) { | ||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state); | |||||
// Support only single return-terminated block in the function. | // Support only single return-terminated block in the function. | ||||
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); | ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); | ||||
assert(returnOp && "expected func with single return op"); | assert(returnOp && "expected func with single return op"); | ||||
for (Value returnVal : returnOp.operands()) | for (Value returnVal : returnOp.operands()) | ||||
if (returnVal.getType().isa<RankedTensorType>()) | if (returnVal.getType().isa<RankedTensorType>()) | ||||
for (BlockArgument bbArg : funcOp.getArguments()) | for (BlockArgument bbArg : funcOp.getArguments()) | ||||
if (bbArg.getType().isa<RankedTensorType>()) | if (bbArg.getType().isa<RankedTensorType>()) | ||||
if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg)) | if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg)) | ||||
state.equivalentReturnValToBBArg[returnVal] = bbArg; | moduleState.equivalentReturnValToBBArg[returnVal] = bbArg; | ||||
} | } | ||||
/// Rewrite the `funcOp` arguments, return values and terminator into | /// Rewrite the `funcOp` arguments, return values and terminator into | ||||
/// buffer form (using the canonical memref layout for now), according to the | /// buffer form (using the canonical memref layout for now), according to the | ||||
/// inPlace-bufferizable information of the function arguments. | /// inPlace-bufferizable information of the function arguments. | ||||
/// | /// | ||||
/// This relies on a buffer equivalence analysis of each return operand. When a | /// This relies on a buffer equivalence analysis of each return operand. When a | ||||
/// result buffer is equivalent to a BlockArgument of `funcOp`, it can be | /// result buffer is equivalent to a BlockArgument of `funcOp`, it can be | ||||
/// dropped from the return values and becomes inplaceable at all callers. This | /// dropped from the return values and becomes inplaceable at all callers. This | ||||
/// assumes all CallOp perform the necessary work to clone operands so as to | /// assumes all CallOp perform the necessary work to clone operands so as to | ||||
/// make them inplaceable. Reliance on this logic will need to be relaxed in the | /// make them inplaceable. Reliance on this logic will need to be relaxed in the | ||||
/// future. | /// future. | ||||
/// | /// | ||||
/// Note: Returning a memref currently fails bufferization. If such memrefs | /// Note: Returning a memref currently fails bufferization. If such memrefs | ||||
/// originate from an op with an Alloc effect, they could be hoisted in the | /// originate from an op with an Alloc effect, they could be hoisted in the | ||||
/// future. | /// future. | ||||
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, | static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, | ||||
ModuleBufferizationState &state) { | BufferizationState &state) { | ||||
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n"); | LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n"); | ||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state); | |||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo; | BufferizationAliasInfo &aliasInfo = state.aliasInfo; | ||||
// If nothing to do then we are done. | // If nothing to do then we are done. | ||||
if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) && | if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) && | ||||
!llvm::any_of(funcOp.getType().getResults(), isaTensor)) | !llvm::any_of(funcOp.getType().getResults(), isaTensor)) | ||||
return success(); | return success(); | ||||
// Get the bufferized FunctionType for funcOp or construct it if not yet | // Get the bufferized FunctionType for funcOp or construct it if not yet | ||||
Show All 15 Lines | static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, | ||||
// bufferization contract they want to enforce atm. | // bufferization contract they want to enforce atm. | ||||
// As a consequence, only support functions that don't return any tensor atm. | // As a consequence, only support functions that don't return any tensor atm. | ||||
if (funcOp.getBody().empty()) { | if (funcOp.getBody().empty()) { | ||||
if (llvm::any_of(funcOp.getType().getResults(), isaTensor)) | if (llvm::any_of(funcOp.getType().getResults(), isaTensor)) | ||||
return funcOp->emitError() << "cannot bufferize bodiless function that " | return funcOp->emitError() << "cannot bufferize bodiless function that " | ||||
<< "returns a tensor"; | << "returns a tensor"; | ||||
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( | FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( | ||||
funcOp, funcOp.getType().getInputs(), TypeRange{}, | funcOp, funcOp.getType().getInputs(), TypeRange{}, | ||||
state.bufferizedFunctionTypes); | moduleState.bufferizedFunctionTypes); | ||||
funcOp.setType(bufferizedFuncType); | funcOp.setType(bufferizedFuncType); | ||||
LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp); | LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp); | ||||
return success(); | return success(); | ||||
} | } | ||||
// Support only single return-terminated block in the function. | // Support only single return-terminated block in the function. | ||||
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); | ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); | ||||
assert(returnOp && "expected func with single return op"); | assert(returnOp && "expected func with single return op"); | ||||
// 1. For each FuncOp result, keep track of which inplace argument it reuses. | // 1. For each FuncOp result, keep track of which inplace argument it reuses. | ||||
SmallVector<Value> returnValues; | SmallVector<Value> returnValues; | ||||
for (OpOperand &returnOperand : returnOp->getOpOperands()) { | for (OpOperand &returnOperand : returnOp->getOpOperands()) { | ||||
Value returnVal = returnOperand.get(); | Value returnVal = returnOperand.get(); | ||||
// If the return value is not a ranked tensor, just forward it. | // If the return value is not a ranked tensor, just forward it. | ||||
if (!returnVal.getType().isa<RankedTensorType>()) { | if (!returnVal.getType().isa<RankedTensorType>()) { | ||||
returnValues.push_back(returnVal); | returnValues.push_back(returnVal); | ||||
continue; | continue; | ||||
} | } | ||||
// If return operand is equivalent to some bbArg, no need to return it. | // If return operand is equivalent to some bbArg, no need to return it. | ||||
if (state.equivalentReturnValToBBArg.count(returnVal)) | if (moduleState.equivalentReturnValToBBArg.count(returnVal)) | ||||
continue; | continue; | ||||
// Cast values at the call site if necessary. | // Cast values at the call site if necessary. | ||||
returnValues.push_back(getNonCastedValue(state.lookupBuffer(returnVal))); | returnValues.push_back(getNonCastedValue(state.lookupBuffer(returnVal))); | ||||
} | } | ||||
// 2. Rewrite the terminator without the inPlace bufferizable values. | // 2. Rewrite the terminator without the inPlace bufferizable values. | ||||
ValueRange retValues{returnValues}; | ValueRange retValues{returnValues}; | ||||
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( | FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( | ||||
funcOp, funcOp.getType().getInputs(), retValues.getTypes(), | funcOp, funcOp.getType().getInputs(), retValues.getTypes(), | ||||
state.bufferizedFunctionTypes); | moduleState.bufferizedFunctionTypes); | ||||
OpBuilder b(returnOp); | OpBuilder b(returnOp); | ||||
b.create<ReturnOp>(returnOp.getLoc(), returnValues); | b.create<ReturnOp>(returnOp.getLoc(), returnValues); | ||||
returnOp->erase(); | returnOp->erase(); | ||||
// 3. Rewrite the bbArgs. | // 3. Rewrite the bbArgs. | ||||
// Iterate on the original `numArgs` and replace them in order. | // Iterate on the original `numArgs` and replace them in order. | ||||
// This guarantees the argument order still matches after the rewrite. | // This guarantees the argument order still matches after the rewrite. | ||||
Block &frontBlock = funcOp.body().front(); | Block &frontBlock = funcOp.body().front(); | ||||
▲ Show 20 Lines • Show All 236 Lines • ▼ Show 20 Lines | struct CallOpInterface | ||||
/// marked inplaceable. For now, it is the responsibility of the `callOp` | /// marked inplaceable. For now, it is the responsibility of the `callOp` | ||||
/// bufferization to allow FuncOp that are inplaceable to write inPlace. | /// bufferization to allow FuncOp that are inplaceable to write inPlace. | ||||
LogicalResult bufferize(Operation *op, OpBuilder &b, | LogicalResult bufferize(Operation *op, OpBuilder &b, | ||||
BufferizationState &state) const { | BufferizationState &state) const { | ||||
CallOp callOp = cast<CallOp>(op); | CallOp callOp = cast<CallOp>(op); | ||||
FuncOp funcOp = getCalledFunction(callOp); | FuncOp funcOp = getCalledFunction(callOp); | ||||
assert(isa<CallOp>(callOp.getOperation()) && funcOp && | assert(isa<CallOp>(callOp.getOperation()) && funcOp && | ||||
"expected Callop to a FuncOp"); | "expected Callop to a FuncOp"); | ||||
auto &moduleState = static_cast<ModuleBufferizationState &>(state); | ModuleBufferizationState &moduleState = getModuleBufferizationState(state); | ||||
// Take a guard before anything else. | // Take a guard before anything else. | ||||
OpBuilder::InsertionGuard g(b); | OpBuilder::InsertionGuard g(b); | ||||
b.setInsertionPoint(callOp); | b.setInsertionPoint(callOp); | ||||
// 1. Filter return types: | // 1. Filter return types: | ||||
// - if the callee is bodiless / external, we cannot inspect it and we | // - if the callee is bodiless / external, we cannot inspect it and we | ||||
// cannot assume anything. We can just assert that it does not return a | // cannot assume anything. We can just assert that it does not return a | ||||
▲ Show 20 Lines • Show All 158 Lines • ▼ Show 20 Lines | |||||
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( | LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( | ||||
ModuleOp moduleOp, const BufferizationOptions &options) { | ModuleOp moduleOp, const BufferizationOptions &options) { | ||||
SmallVector<FuncOp> orderedFuncOps; | SmallVector<FuncOp> orderedFuncOps; | ||||
DenseMap<FuncOp, DenseSet<Operation *>> callerMap; | DenseMap<FuncOp, DenseSet<Operation *>> callerMap; | ||||
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) | if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) | ||||
return failure(); | return failure(); | ||||
ModuleBufferizationState state(moduleOp, *options.allocationFns); | BufferizationState state(moduleOp, *options.allocationFns); | ||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo; | BufferizationAliasInfo &aliasInfo = state.aliasInfo; | ||||
// Interestingly, all function args that are not visible outside of a module | // Interestingly, all function args that are not visible outside of a module | ||||
// can be fully bufferized inplace by guaranteeing the CallOp is bufferized | // can be fully bufferized inplace by guaranteeing the CallOp is bufferized | ||||
// inplace. Therefore, we just bufferize funcOp as if none of its results were | // inplace. Therefore, we just bufferize funcOp as if none of its results were | ||||
// inplaceable, detect which operands are cloned internally and decide what to | // inplaceable, detect which operands are cloned internally and decide what to | ||||
// do at call sites. | // do at call sites. | ||||
for (FuncOp funcOp : orderedFuncOps) { | for (FuncOp funcOp : orderedFuncOps) { | ||||
▲ Show 20 Lines • Show All 51 Lines • Show Last 20 Lines |