diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -197,6 +197,8 @@ /// is returned regardless of whether it is a memory write or not. Value findLastPrecedingWrite(Value value); +struct BufferizationState; + /// Callback functions that are used to allocate/deallocate/copy memory buffers. /// Comprehensive Bufferize provides default implementations of these functions. // TODO: Could be replaced with a "bufferization strategy" object with virtual @@ -207,8 +209,7 @@ using DeallocationFn = std::function; using MemCpyFn = std::function; using CreateAllocDeallocFn = - std::function; + std::function; AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn, MemCpyFn copyFn, CreateAllocDeallocFn allocDeallocFn) @@ -230,13 +231,49 @@ CreateAllocDeallocFn createAllocDeallocFn; }; +/// BufferizationState keeps track of bufferization state and provides access to +/// the results of the analysis. +/// +/// The state does not own the objects it refers to: `aliasInfo`, +/// `allocationFns` and `tensorToBufferMap` are stored as references, so the +/// referenced objects must outlive this state. +struct BufferizationState { + BufferizationState(BufferizationAliasInfo &aliasInfo, + AllocationCallbacks &allocationFns, + BlockAndValueMapping &tensorToBufferMap) + : aliasInfo(aliasInfo), allocationFns(allocationFns), + tensorToBufferMap(tensorToBufferMap) {} + + /// Map tensor values to memref buffers. `tensors` must not be empty. + void mapBuffer(ValueRange tensors, ValueRange buffers); + + /// Map a tensor value to a memref buffer. `tensor` must be non-null and + /// have a tensor type. + void mapBuffer(Value tensor, Value buffer); + + /// Lookup the memref buffer that is associated to the given tensor value. + /// `tensor` must have a tensor type. Looking up a tensor without an + /// associated buffer is a fatal error: the tensor is dumped and + /// `llvm_unreachable` is hit. + Value lookupBuffer(Value tensor) const; + + /// `aliasInfo` keeps track of aliasing and equivalent values.
+ BufferizationAliasInfo &aliasInfo; + + /// `allocationFns` contains helper functions for creating alloc ops, dealloc + /// ops and memcpy ops. + AllocationCallbacks &allocationFns; + + /// The mapping of tensors to buffers. + BlockAndValueMapping &tensorToBufferMap; +}; + /// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. +/// Returns a null Value (after emitting an error on the owning op) if the +/// aliasing OpOperands are not all mapped to the same buffer. -Value getResultBuffer(OpBuilder &b, OpResult result, - const BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks allocationFns); +Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state); } // namespace comprehensive_bufferize } // namespace linalg @@ -280,9 +317,7 @@ bool isWritable(Operation *op, Value value) const { return false; } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto isaTensor = [](Type t) { return t.isa(); }; if (any_of(op->getOperandTypes(), isaTensor) || any_of(op->getResultTypes(), isaTensor)) diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -160,8 +160,6 @@ llvm_unreachable("bufferRelation not implemented"); }] >, - // TODO: Simplify method signature: Pass an OpBuilder and a - // BufferizationState object. InterfaceMethod< /*desc=*/[{ Bufferize this op, i.e., rewrite it into a memref-based equivalent.
@@ -171,9 +169,7 @@ /*retType=*/"LogicalResult", /*methodName=*/"bufferize", /*args=*/(ins "OpBuilder &":$b, - "BlockAndValueMapping &":$bvm, - "BufferizationAliasInfo &":$aliasInfo, - "AllocationCallbacks &":$allocationFn), + "BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ llvm_unreachable("bufferize not implemented"); diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -27,6 +27,8 @@ // TODO: from some HW description. static constexpr int64_t kBufferAlignments = 128; +struct BufferizationState; + /// Analyze the `ops` to determine which OpResults are inplaceable. LogicalResult inPlaceAnalysis(SmallVector &ops, BufferizationAliasInfo &aliasInfo, @@ -55,9 +57,7 @@ /// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be /// non-null if `op` is a CallOpInterface (resp. GlobalCreator). LogicalResult -bufferizeOp(Operation *op, BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks allocationFns, +bufferizeOp(Operation *op, BufferizationState &state, DenseMap *bufferizedFunctionTypes = nullptr); /// Register external models implemented for the `BufferizableOpInterface`.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -7,8 +7,11 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" #include "llvm/Support/Debug.h" namespace mlir { @@ -319,30 +322,28 @@ /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. Value mlir::linalg::comprehensive_bufferize::getResultBuffer( - OpBuilder &b, OpResult result, const BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, AllocationCallbacks allocationFns) { + OpBuilder &b, OpResult result, BufferizationState &state) { OpBuilder::InsertionGuard guard(b); Operation *op = result.getOwner(); SmallVector aliasingOperands = getAliasingOpOperand(result); assert(!aliasingOperands.empty() && "could not get aliasing OpOperand"); OpOperand *opOperand = aliasingOperands.front(); Value operand = opOperand->get(); - Value operandBuffer = bvm.lookupOrNull(operand); - assert(operandBuffer && "operand buffer not found"); + Value operandBuffer = state.lookupBuffer(operand); // Make sure that all OpOperands are the same buffer. If this is not the case, // we would have to materialize a memref value. // TODO: Should be looking for checking for "equivalent buffers" instead of // operator== here, but equivalent buffers for scf.if yield values are not // set up yet.
if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) { - return bvm.lookup(o->get()) == operandBuffer; + return state.lookupBuffer(o->get()) == operandBuffer; })) { op->emitError("result buffer is ambiguous"); return Value(); } // If bufferizing out-of-place, allocate a new buffer. - if (!aliasInfo.isInPlace(result)) { + if (!state.aliasInfo.isInPlace(result)) { // Ops with multiple aliasing operands can currently not bufferize // out-of-place. assert( @@ -350,8 +351,8 @@ "ops with multiple aliasing OpOperands cannot bufferize out-of-place"); Location loc = op->getLoc(); // Allocate the result buffer. - Value resultBuffer = allocationFns.createAllocDeallocFn( - b, loc, operand, aliasInfo, allocationFns); + Value resultBuffer = + state.allocationFns.createAllocDeallocFn(b, loc, operand, state); bool skipCopy = false; // Do not copy if the last preceding write of `operand` is an op that does // not write (skipping ops that merely create aliases). E.g., InitTensorOp. @@ -373,7 +374,7 @@ if (!skipCopy) { // Set insertion point now that potential alloc/dealloc are introduced. b.setInsertionPoint(op); - allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer); + state.allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer); } return resultBuffer; } @@ -381,3 +382,40 @@ // Bufferizing in-place. No need to allocate a new buffer. return operandBuffer; } + +//===----------------------------------------------------------------------===// +// BufferizationState: tensor-to-buffer mapping helpers. +//===----------------------------------------------------------------------===// + +/// Map `tensors` to `buffers`. The ranges must be non-empty and of equal size. +void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer( + ValueRange tensors, ValueRange buffers) { + assert(!tensors.empty() && "unexpected empty tensors"); + assert(tensors.size() == buffers.size() && + "expected exactly one buffer per tensor"); + tensorToBufferMap.map(tensors, buffers); +} + +/// Map a single, non-null tensor value to `buffer`.
+void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer( + Value tensor, Value buffer) { + assert(tensor && "unexpected empty tensor"); + assert(tensor.getType().isa() && "unexpected non-tensor type"); + tensorToBufferMap.map(tensor, buffer); +} + +/// Return the buffer mapped to `tensor`, which must have a tensor type. An +/// unmapped tensor is a fatal error: it is dumped and `llvm_unreachable` is +/// hit, so callers need not check the result for null. +Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer( + Value tensor) const { + // TODO: if key comes from bbArg, forward. + assert(tensor.getType().isa() && "unexpected non-tensor type"); + Value v = tensorToBufferMap.lookupOrNull(tensor); + if (!v) { + // Dump tensor for easier debugging before aborting. + tensor.dump(); + llvm_unreachable("tensor is not mapped"); + } + return v; +} diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -172,47 +172,6 @@ return returnOp; } -//===----------------------------------------------------------------------===// -// Bufferization-specific BlockAndValueMapping support with debugging. -//===----------------------------------------------------------------------===// - -/// Wrapper for better debugging. -static void map(BlockAndValueMapping &bvm, ValueRange keys, ValueRange values) { - assert(!keys.empty() && "Unexpected empty keys"); - LDBG("\n\tMap: " << printValueInfo(keys.front()) - << "\n\tto: " << printValueInfo(values.front()) << '\n'); - return bvm.map(keys, values); -} - -/// Wrapper for better debugging. -static void map(BlockAndValueMapping &bvm, Value key, Value value) { - LDBG("\n\tMap: " << printValueInfo(key) << "\n\tto: " << printValueInfo(value) - << '\n'); - return bvm.map(key, value); -} - -/// Wrapper for better debugging.
-static Value lookup(const BlockAndValueMapping &bvm, Value key) { - // TODO: if key comes from bbArg, forward. - assert(key.getType().isa()); - Value v = bvm.lookupOrNull(key); - if (v) - return v; - - Operation *parentOp; - if (auto bbArg = key.dyn_cast()) { - if (isa(key.getParentBlock()->getParentOp())) - parentOp = key.getParentBlock()->getParentOp(); - else - parentOp = key.getParentBlock()->getParentOp()->getParentOfType(); - } else { - parentOp = key.getDefiningOp()->getParentOfType(); - } - LDBG("In func:\n" << *parentOp << "\nNO VALUE FOR KEY: " << key << '\n'); - (void)parentOp; - return Value(); -} - //===----------------------------------------------------------------------===// // Bufferization-specific attribute manipulation. // These are for testing and debugging only. Bufferization information is @@ -878,8 +837,7 @@ /// `shapedValue.getDefiningOp` (or at the top of the block in case of a /// bbArg) and the DeallocOp is at the end of the block. static Value createNewAllocDeallocPairForShapedValue( - OpBuilder &b, Location loc, Value shapedValue, - BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) { + OpBuilder &b, Location loc, Value shapedValue, BufferizationState &state) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -891,19 +849,19 @@ MemRefType allocMemRefType = getAllocationTypeAndShape(b, loc, shapedValue, dynShape); Optional allocated = - allocationFns.allocationFn(b, loc, allocMemRefType, dynShape); + state.allocationFns.allocationFn(b, loc, allocMemRefType, dynShape); // TODO: For now just assert the value is returned. Eventually need to // error-propagate. assert(allocated && "allocation failed"); Value casted = allocated.getValue(); if (memRefType && memRefType != allocMemRefType) { casted = b.create(loc, memRefType, allocated.getValue()); - aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue()); + state.aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue()); } // 2.
Create memory deallocation. b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator()); - allocationFns.deallocationFn(b, loc, allocated.getValue()); + state.allocationFns.deallocationFn(b, loc, allocated.getValue()); return casted; } @@ -915,8 +873,7 @@ /// inplaceable. For now, it is the responsibility of the `callOp` bufferization /// to allow FuncOp that are inplaceable to write inPlace. static LogicalResult -bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns, +bufferize(OpBuilder &b, CallOpInterface callOp, BufferizationState &state, DenseMap &bufferizedFunctionTypes) { FuncOp funcOp = getCalledFunction(callOp); assert(isa(callOp.getOperation()) && funcOp && @@ -962,14 +919,14 @@ // If return operand is equivalent to some bbArg, no need to return it. Value returnVal = returnOperand.get(); if (BlockArgument bbArg = - getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo)) { + getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) { Value oldRes = callOp->getResult(returnOperand.getOperandNumber()); int64_t idx = bbArg.getArgNumber(); - Value buffer = lookup(bvm, callOp->getOperand(idx)); - assert(buffer && "expected bufferized value"); + // Unmapped tensors are fatal in `lookupBuffer`, so no null check. + Value buffer = state.lookupBuffer(callOp->getOperand(idx)); // Add CallOp operand/result equivalence: this is interprocedural info. - aliasInfo.insertNewBufferEquivalence(oldRes, buffer); - map(bvm, oldRes, buffer); + state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer); + state.mapBuffer(oldRes, buffer); // Add a TensorLoadOp to kill all uses of the CallOp return. // Replace all uses of the CallOp results so we can erase the CallOp. // This TensorLoadOp must fold/DCE away or bufferization should be @@ -978,13 +934,13 @@ b.create(callOp.getLoc(), buffer); oldRes.replaceAllUsesWith(tensorLoad); // Add new op equivalence info.
- aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer); - map(bvm, tensorLoad, buffer); + state.aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer); + state.mapBuffer(tensorLoad, buffer); continue; } // TODO: Need to hoist above function boundary. - if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo)) { + if (Operation *allocOp = getEquivalentAlloc(returnVal, state.aliasInfo)) { hoistedArguments.push_back(allocOp->getResult(0)); continue; } @@ -1023,8 +979,8 @@ // Tensor operands are guaranteed to have been buferized. int64_t idx = opOperand.getOperandNumber(); - Value buffer = lookup(bvm, tensorOperand); - assert(buffer && "expected bufferized value"); + // No null check needed: `lookupBuffer` aborts on unmapped tensors. + Value buffer = state.lookupBuffer(tensorOperand); // Caller / callee type mistmatch is handled with a CastOp. auto memRefType = bufferizedFuncType.getInput(idx); @@ -1037,8 +992,8 @@ Value castBuffer = b.create(callOp.getLoc(), memRefType, buffer); // Add new op equivalence info. - aliasInfo.insertNewBufferEquivalence(castBuffer, buffer); - map(bvm, tensorOperand, castBuffer); + state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer); + state.mapBuffer(tensorOperand, castBuffer); buffer = castBuffer; } newOperands.push_back(buffer); @@ -1054,9 +1009,7 @@ /// FuncOp always creates TensorToMemRef ops. static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) { + BufferizationState &state) { // Take a guard before anything else.
OpBuilder::InsertionGuard g(b); b.setInsertionPointToStart(&funcOp.body().front()); @@ -1072,8 +1025,8 @@ : getContiguousOrUnrankedMemRefType(tensorType); Value bufferCast = b.create(funcOp.getLoc(), memRefType, bbArg); - aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg); - map(bvm, bbArg, bufferCast); + state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg); + state.mapBuffer(bbArg, bufferCast); } return success(); } @@ -1230,8 +1183,7 @@ } LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp( - Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo, - AllocationCallbacks allocationFns, + Operation *op, BufferizationState &state, DenseMap *bufferizedFunctionTypes) { OpBuilder b(op->getContext()); @@ -1241,8 +1193,7 @@ if (!bufferizedFunctionTypes) llvm_unreachable( "null bufferizedFunctionTypes when bufferizing CallOpInterface"); - return bufferize(b, callOp, bvm, aliasInfo, allocationFns, - *bufferizedFunctionTypes); + return bufferize(b, callOp, state, *bufferizedFunctionTypes); } // Skip BufferCast and TensorLoad ops. @@ -1251,7 +1202,7 @@ // Bufferize using `BufferizableOpInterface`. if (auto bufferizableOp = dyn_cast(op)) - return bufferizableOp.bufferize(b, bvm, aliasInfo, allocationFns); + return bufferizableOp.bufferize(b, state); // Other op with tensors. No bufferization method specified. auto isaTensor = [](Type t) { return t.isa(); }; @@ -1262,23 +1213,21 @@ } static LogicalResult bufferizeFuncOpInternals( - FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFns, + FuncOp funcOp, BufferizationState &state, DenseMap &bufferizedFunctionTypes) { LLVM_DEBUG(llvm::dbgs() << "\n\n"); LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n'); OpBuilder b(funcOp->getContext()); // Start by bufferizing `funcOp` arguments.
- if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns))) + if (failed(bufferize(b, funcOp, state))) return failure(); // Cannot erase ops during the traversal. Do that afterwards. SmallVector toErase; auto walkFunc = [&](Operation *op) -> WalkResult { - if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns, - &bufferizedFunctionTypes))) + if (failed(bufferizeOp(op, state, &bufferizedFunctionTypes))) return failure(); // Register post-walk erasure, if necessary. @@ -1852,9 +1801,11 @@ // Bufferization phase. if (!options.testAnalysisOnly) { BlockAndValueMapping tensorToBufferMap; - if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo, - *options.allocationFns, - bufferizedFunctionTypes))) + // `state` only holds references; `tensorToBufferMap` above outlives it. + BufferizationState state(aliasInfo, *options.allocationFns, + tensorToBufferMap); + if (failed( + bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes))) return failure(); } } @@ -1926,9 +1876,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto constantOp = cast(op); if (!isaTensor(constantOp.getResult().getType())) return success(); @@ -1948,8 +1896,8 @@ auto globalMemref = globalCreator.getGlobalFor(constantOp); Value memref = b.create( constantOp.getLoc(), globalMemref.type(), globalMemref.getName()); - aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult()); - map(bvm, constantOp, memref); + state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult()); + state.mapBuffer(constantOp, memref); return success(); } @@ -1969,10 +1917,10 @@ /// Helper function for LinalgOp bufferization. /// When allocating a new buffer, analyze whether `op` wants to read form that /// buffer. Only in that case, a copy of the result buffer may be needed.
-static LogicalResult allocateBuffersForResults( - OpBuilder &b, Location loc, LinalgOp op, - SmallVectorImpl &resultBuffers, BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) { +static LogicalResult +allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op, + SmallVectorImpl &resultBuffers, + BufferizationState &state) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); b.setInsertionPoint(op); @@ -1983,24 +1931,21 @@ OpResult opResult = cast(op.getOperation()) .getAliasingOpResult(*opOperand); assert(opResult && "could not find correspond OpResult"); - Value resultBuffer = - getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns); + Value resultBuffer = getResultBuffer(b, opResult, state); if (!resultBuffer) return failure(); resultBuffers.push_back(resultBuffer); } if (op->getNumResults()) - map(bvm, op->getResults(), resultBuffers); + state.mapBuffer(op->getResults(), resultBuffers); return success(); } /// Generic conversion for any LinalgOp on tensors. static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFns) { + BufferizationState &state) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -2017,13 +1962,12 @@ newInputBuffers.push_back(opOperand->get()); continue; } - newInputBuffers.push_back(lookup(bvm, opOperand->get())); - assert(newInputBuffers.back() && "missing buffer"); + // `lookupBuffer` aborts on unmapped tensors, so this is never null. + newInputBuffers.push_back(state.lookupBuffer(opOperand->get())); } SmallVector newOutputBuffers; // Try to allocate new buffers depending on op's inplace semantics. - if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm, - aliasInfo, allocationFns))) + if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, state))) return failure(); // Clone the newly bufferized op.
@@ -2036,7 +1979,7 @@ // Replace the results of the old op with the new output buffers. if (op->getNumResults()) - map(bvm, op->getResults(), newOutputBuffers); + state.mapBuffer(op->getResults(), newOutputBuffers); // The original op will be DCE'd away later. @@ -2087,11 +2030,8 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { - return bufferizeLinalgOp(b, cast(op), bvm, aliasInfo, - allocationFn); + BufferizationState &state) const { + return bufferizeLinalgOp(b, cast(op), state); } }; @@ -2109,9 +2049,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto initTensorOp = cast(op); // The InitTensorOp may have been eliminated. @@ -2123,9 +2061,8 @@ b.setInsertionPoint(initTensorOp); Value alloc = createNewAllocDeallocPairForShapedValue( - b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo, - allocationFn); - map(bvm, initTensorOp.result(), alloc); + b, initTensorOp->getLoc(), initTensorOp.result(), state); + state.mapBuffer(initTensorOp.result(), alloc); return success(); } }; @@ -2178,9 +2115,7 @@ bool isAllocationHoistingBarrier(Operation *op) const { return true; } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto tiledLoopOp = cast(op); // Take a guard before anything else.
@@ -2222,15 +2157,14 @@ const OpResult &opResult = tiledLoopOp->getResult(resultIndex); OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex); - Value resultBuffer = - getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn); + Value resultBuffer = getResultBuffer(b, opResult, state); if (!resultBuffer) return failure(); // Insert mapping and aliasing info. - aliasInfo.createAliasInfoEntry(resultBuffer); - aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer); - map(bvm, opResult, resultBuffer); + state.aliasInfo.createAliasInfoEntry(resultBuffer); + state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer); + state.mapBuffer(opResult, resultBuffer); // Insert new operand and bbArg. tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer); @@ -2238,9 +2172,10 @@ body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType()); BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex); // Insert mapping and aliasing info. - aliasInfo.createAliasInfoEntry(newBufferBBArg); - aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg); - map(bvm, oldTensorBBArg, newBufferBBArg); + state.aliasInfo.createAliasInfoEntry(newBufferBBArg); + state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, + newBufferBBArg); + state.mapBuffer(oldTensorBBArg, newBufferBBArg); // Set operand of `linalg.yield` to the bbArg so it just canonicalizes // away later. @@ -2268,8 +2203,8 @@ continue; } - Value inputBuffer = lookup(bvm, oldInputTensor); - assert(inputBuffer && " missing buffer for operand"); + // `lookupBuffer` aborts on unmapped tensors; no null check needed. + Value inputBuffer = state.lookupBuffer(oldInputTensor); // Insert new operand and bbArg. tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer); @@ -2278,9 +2212,10 @@ BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex); // Insert mapping and aliasing info.
- aliasInfo.createAliasInfoEntry(newBufferBBArg); - aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg); - map(bvm, oldTensorBBArg, newBufferBBArg); + state.aliasInfo.createAliasInfoEntry(newBufferBBArg); + state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, + newBufferBBArg); + state.mapBuffer(oldTensorBBArg, newBufferBBArg); // Increment indices. ++numNewInputBuffers; @@ -2318,9 +2253,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto yieldOp = cast(op); // Take a guard before anything else. @@ -2394,9 +2327,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { // scf::IfOp is bufferized after scf::YieldOp in the else branch. return success(); } @@ -2405,9 +2336,7 @@ /// Bufferize the scf::IfOp. This function is called after the YieldOp was /// bufferized. static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) { + BufferizationState &state) { // Take a guard before anything else.
OpBuilder::InsertionGuard g(b); b.setInsertionPoint(ifOp); @@ -2420,13 +2349,12 @@ assert(opResult.getType().isa() && "unsupported unranked tensor"); - Value resultBuffer = - getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn); + Value resultBuffer = getResultBuffer(b, opResult, state); if (!resultBuffer) return failure(); - aliasInfo.createAliasInfoEntry(resultBuffer); - map(bvm, opResult, resultBuffer); + state.aliasInfo.createAliasInfoEntry(resultBuffer); + state.mapBuffer(opResult, resultBuffer); } return success(); @@ -2477,9 +2405,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { // Note: This method is just setting up the mappings for the block arguments // and the result buffer. The op is bufferized after the scf::YieldOp. @@ -2497,17 +2423,16 @@ "unsupported unranked tensor"); // TODO: More general: Matching bbArg does not bufferize to a read. - Value resultBuffer = - getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn); + Value resultBuffer = getResultBuffer(b, opResult, state); if (!resultBuffer) return failure(); OpOperand &opOperand = forOp.getOpOperandForResult(opResult); BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand); - aliasInfo.createAliasInfoEntry(resultBuffer); - aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer); - map(bvm, bbArg, resultBuffer); - map(bvm, opResult, resultBuffer); + state.aliasInfo.createAliasInfoEntry(resultBuffer); + state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer); + state.mapBuffer(bbArg, resultBuffer); + state.mapBuffer(opResult, resultBuffer); } return success(); @@ -2517,9 +2442,7 @@ /// Bufferize the scf::ForOp. This function is called after the YieldOp was /// bufferized.
static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) { + BufferizationState &state) { auto yieldOp = cast(&forOp.region().front().back()); for (OpOperand &operand : yieldOp->getOpOperands()) { auto tensorType = operand.get().getType().dyn_cast(); @@ -2529,9 +2452,11 @@ OpOperand &forOperand = forOp.getOpOperandForResult( forOp->getResult(operand.getOperandNumber())); auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - Value yieldedBuffer = lookup(bvm, operand.get()); - Value bbArgBuffer = lookup(bvm, bbArg); - if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) { + // Both lookups abort on unmapped tensors, so neither buffer is null. + Value yieldedBuffer = state.lookupBuffer(operand.get()); + Value bbArgBuffer = state.lookupBuffer(bbArg); + if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, + bbArgBuffer)) { // TODO: this could get resolved with copies but it can also turn into // swaps so we need to be careful about order of copies. return yieldOp->emitError() @@ -2567,9 +2491,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto yieldOp = cast(op); if (auto execOp = dyn_cast(yieldOp->getParentOp())) { @@ -2584,12 +2506,12 @@ if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { if (ifOp.elseYield() != yieldOp) return success(); - return bufferizeIfOp(ifOp, b, bvm, aliasInfo, allocationFn); + return bufferizeIfOp(ifOp, b, state); } // Bufferize scf::ForOp after bufferizing the scf::YieldOp.
if (auto forOp = dyn_cast(yieldOp->getParentOp())) - return bufferizeForOp(forOp, b, bvm, aliasInfo, allocationFn); + return bufferizeForOp(forOp, b, state); return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp"); } @@ -2635,9 +2557,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { llvm_unreachable("CallOps are handled separately"); return failure(); } @@ -2659,9 +2579,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto returnOp = cast(op); // Take a guard before anything else. @@ -2675,12 +2593,12 @@ auto tensorType = operand.get().getType().dyn_cast(); if (!tensorType) continue; - Value v = lookup(bvm, operand.get()); - assert(v && "missing buffer for result"); + // `lookupBuffer` aborts on unmapped tensors; `v` is never null. + Value v = state.lookupBuffer(operand.get()); Value returnTensor = b.create(returnOp.getLoc(), v); operand.set(returnTensor); - aliasInfo.insertNewBufferEquivalence(returnTensor, v); - map(bvm, returnTensor, v); + state.aliasInfo.insertNewBufferEquivalence(returnTensor, v); + state.mapBuffer(returnTensor, v); } return success(); } @@ -2715,17 +2632,14 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto castOp = cast(op); // Take a guard before anything else.
OpBuilder::InsertionGuard g(b); b.setInsertionPoint(castOp); - Value resultBuffer = - getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn); + Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state); if (!resultBuffer) return failure(); Type sourceType = resultBuffer.getType(); @@ -2744,8 +2658,8 @@ castOp.getResult().getType(), layout, memorySpace); Value res = b.create(castOp.getLoc(), memRefType, resultBuffer); - aliasInfo.insertNewBufferEquivalence(res, castOp.getResult()); - map(bvm, castOp.getResult(), res); + state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult()); + state.mapBuffer(castOp.getResult(), res); return success(); } }; @@ -2766,9 +2680,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto dimOp = cast(op); // Take a guard before anything else. @@ -2776,8 +2688,7 @@ b.setInsertionPoint(dimOp); if (dimOp.source().getType().isa()) { - Value v = lookup(bvm, dimOp.source()); - assert(v && "missing buffer"); + Value v = state.lookupBuffer(dimOp.source()); dimOp.result().replaceAllUsesWith( b.create(dimOp.getLoc(), v, dimOp.index())); } @@ -2812,9 +2723,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto extractSliceOp = cast(op); // Take a guard before anything else. @@ -2824,18 +2733,16 @@ Location loc = extractSliceOp.getLoc(); // Bail if source was not bufferized. - Value srcMemref = lookup(bvm, extractSliceOp.source()); - if (!srcMemref) - return failure(); + Value srcMemref = state.lookupBuffer(extractSliceOp.source()); auto srcMemrefType = srcMemref.getType().cast(); auto dstTensorType = extractSliceOp.result().getType().cast(); // If not inplaceable, alloc. 
Value alloc; - if (!aliasInfo.isInPlace(extractSliceOp->getResult(0))) + if (!state.aliasInfo.isInPlace(extractSliceOp->getResult(0))) alloc = createNewAllocDeallocPairForShapedValue( - b, loc, extractSliceOp.result(), aliasInfo, allocationFn); + b, loc, extractSliceOp.result(), state); // Set insertion point now that potential alloc/dealloc are introduced. b.setInsertionPoint(extractSliceOp); @@ -2851,17 +2758,18 @@ loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides()); // Insert new alias. - aliasInfo.insertNewBufferAlias(subView, srcMemref); + state.aliasInfo.insertNewBufferAlias(subView, srcMemref); /// If not inplaceable, copy. if (alloc) { // Do not copy if the copied data is never read. if (isValueRead(extractSliceOp.result())) - allocationFn.memCpyFn(b, extractSliceOp.getLoc(), subView, alloc); + state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView, + alloc); subView = alloc; } - map(bvm, extractSliceOp.result(), subView); + state.mapBuffer(extractSliceOp.result(), subView); return success(); } }; @@ -2882,9 +2790,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto extractOp = cast(op); // Take a guard before anything else. 
@@ -2892,7 +2798,7 @@ b.setInsertionPoint(extractOp); Location loc = extractOp.getLoc(); - Value srcMemref = lookup(bvm, extractOp.tensor()); + Value srcMemref = state.lookupBuffer(extractOp.tensor()); Value l = b.create(loc, srcMemref, extractOp.indices()); extractOp.replaceAllUsesWith(l); return success(); @@ -2950,9 +2856,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto insertSliceOp = cast(op); // Take a guard before anything else. @@ -2969,15 +2873,12 @@ // TODO: be very loud about it or even consider failing the pass. // Alloc a copy for `insertSliceOp.dest()`, it will become the result // buffer. - Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm, - aliasInfo, allocationFn); + Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state); if (!dstMemref) return failure(); auto dstMemrefType = dstMemref.getType().cast(); - Value srcMemref = lookup(bvm, insertSliceOp.source()); - if (!srcMemref) - return failure(); + Value srcMemref = state.lookupBuffer(insertSliceOp.source()); auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( insertSliceOp.getSourceType().getRank(), dstMemrefType, @@ -2991,9 +2892,9 @@ // - The result is not inplace. This is the case where the whole tensor is // cloned and the clone needs to be updated. // TODO: Is this necessary? - if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo, + if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo, insertSliceOp) || - !aliasInfo.isInPlace(insertSliceOp->getResult(0))) { + !state.aliasInfo.isInPlace(insertSliceOp->getResult(0))) { LDBG("insert_slice needs extra source copy: " << insertSliceOp.source() << " -> copy\n"); // Take a subview of the dst. 
@@ -3001,11 +2902,12 @@ loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); // Insert new alias. - aliasInfo.insertNewBufferAlias(subView, dstMemref); - allocationFn.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, subView); + state.aliasInfo.insertNewBufferAlias(subView, dstMemref); + state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, + subView); } - map(bvm, insertSliceOp.result(), dstMemref); + state.mapBuffer(insertSliceOp.result(), dstMemref); return success(); } @@ -3035,9 +2937,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto transferReadOp = cast(op); // Take a guard before anything else. @@ -3048,8 +2948,7 @@ return failure(); // TransferReadOp always reads from the bufferized op.source(). - Value v = lookup(bvm, transferReadOp.source()); - assert(v && "missing buffer"); + Value v = state.lookupBuffer(transferReadOp.source()); transferReadOp.sourceMutable().assign(v); return success(); } @@ -3086,9 +2985,7 @@ } LogicalResult bufferize(Operation *op, OpBuilder &b, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) const { + BufferizationState &state) const { auto writeOp = cast(op); // Take a guard before anything else. @@ -3101,15 +2998,14 @@ // Create a new transfer_write on buffer that doesn't have a return value. // Leave the previous transfer_write to dead code as it still has uses at // this point. - Value resultBuffer = - getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn); + Value resultBuffer = getResultBuffer(b, op->getResult(0), state); if (!resultBuffer) return failure(); b.create( writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(), writeOp.permutation_map(), writeOp.in_bounds() ? 
*writeOp.in_bounds() : ArrayAttr()); - map(bvm, op->getResult(0), resultBuffer); + state.mapBuffer(op->getResult(0), resultBuffer); return success(); }