diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -193,6 +193,8 @@ AllocationCallbacks allocationFns, DenseMap *bufferizedFunctionTypes = nullptr); +/// Register external models implemented for the `BufferizableOpInterface`. +void registerBufferiableOpInterfaceExternalModels(DialectRegistry ®istry); } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -296,7 +296,7 @@ /// result can be buferized inPlace. /// If no InPlaceSpec attribute has been set for `opResult`, return /// InPlaceSpec::None. -static InPlaceSpec getInPlace(OpResult opResult) { +LLVM_ATTRIBUTE_UNUSED static InPlaceSpec getInPlace(OpResult opResult) { if (!opResult) return InPlaceSpec::None; @@ -355,7 +355,7 @@ LinalgDialect::kInplaceableAttrName); } -LLVM_ATTRIBUTE_UNUSED static InPlaceSpec getInPlace(Value v) { +static InPlaceSpec getInPlace(Value v) { if (auto bbArg = v.dyn_cast()) return getInPlace(bbArg); return getInPlace(v.cast()); @@ -414,227 +414,39 @@ } //===----------------------------------------------------------------------===// -// Op-specific semantics helper to retrieve matching inplaceable result. -// These should become proper interfaces interfaces when the time is right. -// Modulo better naming, these helpers / interfaces comprise information on: -// 1. Whether an op has a known bufferization behavior (i.e. an instance of -// BufferizableOpInterface). -// 2. 
Whether an op, when bufferized inplace, can guarantee an -// (OpOperand, OpResult) pair bufferizes to equivalent (i.e. the same) -// buffers in memory. -// 3. Whether an op operand, when bufferized inplace, aliases a return value. -// 4. Whether an op return value, when bufferized inplace, aliases an operand. -// 5. Whether an op bufferizes to a memory read. -// 6. Whether an op bufferizes to a memory write. -// 7. The buffer relationship between an operand and it corresponding result -// (in case of in-place bufferization). -// These interfaces are necessary to distinguish between various cases and allow -// special inplace behavior for (ExtractSliceOp, InsertSliceOp) pairs. +// Helper functions for BufferizableOpInterface //===----------------------------------------------------------------------===// -/// Return `true` if the op is explicitly supported by bufferization or if it -/// has no result tensors. -/// Other cases must be conservative. -static bool hasKnownBufferizationAliasingBehavior(Operation *op) { - return - // clang-format off - isa(op) - // clang-format on - || (none_of(op->getResultTypes(), isaTensor) && - none_of(op->getOperandTypes(), isaTensor)); -} - -/// Return the OpResult that may bufferize into the same buffer as `opOperand` -/// when the op is bufferized inplace. -/// Return null if no such result exists. -static OpResult getInplaceableOpResult(TiledLoopOp op, OpOperand &opOperand) { - return op.getTiedOpResult(opOperand); -} - -/// Return the OpResult that may bufferize into the same buffer as `opOperand` -/// when the op is bufferized inplace. -/// Return null if no such result exists. -static OpResult getInplaceableOpResult(scf::ForOp forOp, OpOperand &opOperand) { - if (!opOperand.get().getType().isa()) - return OpResult(); - return forOp.getResultForOpOperand(opOperand); -} - -/// Return the OpResult that may bufferize into the same buffer as `opOperand` -/// when the op is bufferized inplace. 
-/// Return null if no such result exists. -static OpResult getInplaceableOpResult(LinalgOp linalgOp, - OpOperand &opOperand) { - if (!opOperand.get().getType().isa()) - return OpResult(); - // For now assume inputs are never inplaceable. - // TODO: refine this. - if (opOperand.getOperandNumber() < linalgOp.getNumInputs()) - return OpResult(); - int64_t outputOperandIndex = - opOperand.getOperandNumber() - linalgOp.getNumInputs(); - int64_t numOutputBuffers = 0; - for (unsigned idx = 0; idx < outputOperandIndex; ++idx) - if (!linalgOp.getOutputOperand(idx)->get().getType().isa()) - ++numOutputBuffers; - return linalgOp->getResult(outputOperandIndex - numOutputBuffers); -} - -/// Return the OpResult that may bufferize into the same buffer as `opOperand` -/// when the op is bufferized inplace. -/// Return null if no such result exists. -static OpResult getInplaceableOpResult(VectorTransferOpInterface op, - OpOperand &opOperand) { - if (opOperand.get() != op.source() || - !op.source().getType().isa() || - isa(op)) - return OpResult(); - return op->getResult(0); -} - -/// Return the OpResult that may bufferize into the same buffer as `opOperand` -/// when the op is bufferized inplace. -/// Return null if no such result exists. -static OpResult getInplaceableOpResult(InsertSliceOp op, OpOperand &opOperand) { - if (&opOperand != &op->getOpOperand(1) /*dest*/) - return OpResult(); - return op->getResult(0); -} - -/// Return the OpResult that may bufferize into the same buffer as `opOperand` -/// when the op is bufferized inplace. -/// The inplace analysis uses this information along with interfering read -/// analysis to determine which op results reuse the same buffer as some -/// operand. -static OpResult getInplaceableOpResult(OpOperand &opOperand) { - return TypeSwitch(opOperand.getOwner()) - // clang-format off - // Ops that perform destructive updates on operand(s) to produce - // result(s). 
- .Case( - [&](auto op) { return getInplaceableOpResult(op, opOperand); }) - // Some ops just return an alias to an operand when bufferized inplace. - // Such OpResults are never inplaceable on an OpOperand. - .Case( - [] (auto op) { return OpResult(); }) - // CallOpInterface is special, it needs to wait for the callee to be - // bufferized and needs to inspect the BufferAliasInfo object. It can't - // make a proper determination by itself and needs to be conservative. - .Case([&](CallOpInterface op) { return OpResult(); }) - // Other ops. - .Default([&](Operation *op) { return OpResult(); }); - // clang-format on -} - -/// Either one of the corresponding yield values from the then/else branches -/// may alias with the result. -static void populateAliasingOpOperands(scf::IfOp op, OpResult result, - SmallVector &operands) { - size_t resultNum = std::distance(op->getOpResults().begin(), - llvm::find(op->getOpResults(), result)); - operands.push_back(&op.thenYield()->getOpOperand(resultNum)); - operands.push_back(&op.elseYield()->getOpOperand(resultNum)); -} - /// Determine which OpOperand* will alias with `result` if the op is bufferized -/// in place. Note that multiple OpOperands can may potentially alias with an -/// OpResult. E.g.: std.select in the future. +/// in place. Return an empty vector if the op is not bufferizable. static SmallVector getAliasingOpOperand(OpResult result) { - SmallVector r; - // Unknown ops are handled conservatively and never bufferize in-place. - if (!hasKnownBufferizationAliasingBehavior(result.getDefiningOp())) - return SmallVector(); - TypeSwitch(result.getDefiningOp()) - .Case([&](scf::IfOp op) { populateAliasingOpOperands(op, result, r); }) - .Case( - [&](auto op) { r.push_back(&op->getOpOperand(0)); }) - // In the case of scf::ForOp, this currently assumes the iter_args / yield - // are 1-1. This may fail and is verified at the end. - // TODO: update this. 
- .Case([&](scf::ForOp op) { - r.push_back(&op.getIterOpOperands()[result.getResultNumber()]); - }) - .Case([&](InsertSliceOp op) { r.push_back(&op->getOpOperand(1)); }) - .Case([&](LinalgOp op) { - r.push_back(op.getOutputTensorOperands()[result.getResultNumber()]); - }) - .Case([&](TiledLoopOp op) { - // TODO: TiledLoopOp helper method to avoid leaking impl details. - r.push_back(&op->getOpOperand(op.getNumControlOperands() + - op.getNumInputs() + - result.getResultNumber())); - }) - .Case([&](vector::TransferWriteOp op) { - r.push_back(&op->getOpOperand(1)); - }) - .Case( - [&](auto op) {}) - .Default([&](Operation *op) { - op->dump(); - llvm_unreachable("unexpected defining op"); - }); - return r; -} - -/// If the an ExtractSliceOp is bufferized in-place, the source operand will -/// alias with the result. -static OpResult getAliasingOpResult(ExtractSliceOp op, OpOperand &opOperand) { - if (&op->getOpOperand(0) == &opOperand) - return op->getResult(0); - return OpResult(); -} - -/// If the a tensor::CastOp is bufferized in-place, the source operand will -/// alias with the result. -static OpResult getAliasingOpResult(tensor::CastOp op, OpOperand &opOperand) { - if (&op->getOpOperand(0) == &opOperand) - return op->getResult(0); - return OpResult(); + if (Operation *op = result.getDefiningOp()) + if (auto bufferizableOp = dyn_cast(op)) + return bufferizableOp.getAliasingOpOperand(result); + return {}; } /// Determine which OpResult will alias with `opOperand` if the op is bufferized -/// in place. This is a superset of `getInplaceableOpResult`. -/// TODO: in the future this may need to evolve towards a list of OpResult. +/// in place. This is a superset of `getInplaceableOpResult`. Return an empty +/// OpResult if the op is not bufferizable. 
static OpResult getAliasingOpResult(OpOperand &opOperand) { - return TypeSwitch(opOperand.getOwner()) - // Some ops are different: Their result is not inplaceable on an OpOperand - // but when bufferized inplace, their result is aliasing (a subregion of) - // an OpOperand. - .Case( - [&](auto op) { return getAliasingOpResult(op, opOperand); }) - // All other ops, return the result of `getInplaceableOpResult`. - .Default( - [&](Operation *op) { return getInplaceableOpResult(opOperand); }); + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.getAliasingOpResult(opOperand); + return OpResult(); } /// Return `true` if the given OpOperand does not bufferize to a memory read or -/// write, but creates an alias when bufferized inplace. +/// write, but creates an alias when bufferized inplace. Return `false` if the +/// op is not bufferizable. static bool bufferizesToAliasOnly(OpOperand &opOperand) { - Operation *owner = opOperand.getOwner(); - // TODO: In the future this may need to evolve into a TypeSwitch. For all - // currently supported ops, the aliasing-only OpOperand is always the first - // one. - return isa(owner) && - &opOperand == &owner->getOpOperand(0); + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.bufferizesToAliasOnly(opOperand); + + // Unknown op that returns a tensor. The inplace analysis does not support it. + // Conservatively return false. + return false; } // Predeclaration of function. @@ -661,72 +473,40 @@ return false; } -/// Return true if `opOperand` bufferizes to a memory read. +/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the +/// op is not bufferizable. static bool bufferizesToMemoryRead(OpOperand &opOperand) { - // Unknown op that returns a tensor. The inplace analysis does not support - // it. Conservatively return true. 
- if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner())) - return true; - // Some ops alone do not bufferize to a memory read, but one of their uses - // may. - if (bufferizesToAliasOnly(opOperand)) - return false; - // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its - // matching bbArg may. - if (auto forOp = dyn_cast(opOperand.getOwner())) - return isValueRead(forOp.getRegionIterArgForOpOperand(opOperand)); - // TiledLoop alone doesn't bufferize to a memory read, one of the uses of its - // matching bbArg may. - if (auto tiledLoopOp = dyn_cast(opOperand.getOwner())) - return isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand)); - // CallOpInterface alone doesn't bufferize to a memory read, one of the uses - // of the matching bbArg may. It is the responsibility of the caller to - // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be - // conservative. - if (auto callOp = dyn_cast(opOperand.getOwner())) - return true; - if (auto linalgOp = dyn_cast(opOperand.getOwner())) - return linalgOp.isInputTensor(&opOperand) || - linalgOp.isInitTensor(&opOperand); - // All other cases are considered to bufferize to memory reads. - // In particular, terminators are often the last use and need to be considered - // as reads to return the proper value and avoid WAW clobbers. + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.bufferizesToMemoryRead(opOperand); + + // Unknown op that returns a tensor. The inplace analysis does not support it. + // Conservatively return true. return true; } -/// Return true if `opOperand` bufferizes to a memory write. +/// Return true if `opOperand` bufferizes to a memory write. Return +/// `true` if the op is not bufferizable. static bool bufferizesToMemoryWrite(OpOperand &opOperand) { - // These terminators are not writes. 
- if (isa(opOperand.getOwner())) - return false; - // Some ops alone do not bufferize to a memory write, but one of their uses - // may. - if (bufferizesToAliasOnly(opOperand)) - return false; - // CallOpInterface alone doesn't bufferize to a memory write, one of the uses - // of the matching bbArg may. It is the responsibility of the caller to - // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be - // conservative. - if (auto callOp = dyn_cast(opOperand.getOwner())) - return true; - // Unknown op that returns a tensor. The inplace analysis does not support - // it. Conservatively return true. - if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner())) - return true; - OpResult opResult = getAliasingOpResult(opOperand); - // Only supported op with a matching result for opOperand bufferize to a - // write. E.g., ReturnOp does not bufferize to a write. - return static_cast(opResult); + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.bufferizesToMemoryWrite(opOperand); + + // Unknown op that returns a tensor. The inplace analysis does not support it. + // Conservatively return true. + return true; } -/// Returns the relationship between the operand and the its corresponding -/// OpResult that it may alias with. -static BufferRelation bufferRelation(OpOperand &operand) { - return TypeSwitch(operand.getOwner()) - // ExtractSliceOp returns a subview of the original tensor. - .Case([&](ExtractSliceOp op) { return BufferRelation::None; }) - // All other ops: Buffers are equivalent. - .Default([&](Operation *op) { return BufferRelation::Equivalent; }); +/// Return the relationship between the operand and the its corresponding +/// OpResult that it may alias with. Return None if the op is not bufferizable. 
+static BufferRelation bufferRelation(OpOperand &opOperand) { + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.bufferRelation(opOperand); + + // Unknown op that returns a tensor. The inplace analysis does not support it. + // Conservatively return None. + return BufferRelation::None; } //===----------------------------------------------------------------------===// @@ -814,7 +594,7 @@ if (Operation *op = v.getDefiningOp()) { if (isa(op) || - !hasKnownBufferizationAliasingBehavior(op)) { + !dyn_cast(op)) { LDBG("-----------notWritable op\n"); return true; } @@ -934,13 +714,14 @@ Operation *op = value.getDefiningOp(); if (!op) return true; - if (!hasKnownBufferizationAliasingBehavior(op)) + auto bufferizableOp = dyn_cast(op); + if (!bufferizableOp) return true; if (isa(op)) return true; SmallVector opOperands = - getAliasingOpOperand(value.cast()); + bufferizableOp.getAliasingOpOperand(value.cast()); assert(opOperands.size() <= 1 && "op with multiple aliasing OpOperands not expected"); @@ -1507,83 +1288,6 @@ return operandBuffer; } -/// Helper function for LinalgOp bufferization. -/// When allocating a new buffer, analyze whether `op` wants to read form that -/// buffer. Only in that case, a copy of the result buffer may be needed. -static LogicalResult allocateBuffersForResults( - OpBuilder &b, Location loc, LinalgOp op, - SmallVectorImpl &resultBuffers, BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - - // TODO: provide the proper interface to iterate on OpResults and get the - // matching OpOperands. 
- for (OpOperand *opOperand : op.getOutputOperands()) { - OpResult opResult = getInplaceableOpResult(*opOperand); - assert(opResult && "could not find correspond OpResult"); - bool skipCopy = !op.payloadUsesValueFromOperand(opOperand); - Value resultBuffer = - getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns, skipCopy); - if (!resultBuffer) - return failure(); - resultBuffers.push_back(resultBuffer); - } - - if (op->getNumResults()) - map(bvm, op->getResults(), resultBuffers); - - return success(); -} - -/// Generic conversion for any LinalgOp on tensors. -static LogicalResult bufferize(OpBuilder &b, LinalgOp op, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFns) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - - // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need - // basis. - if (!op.hasTensorSemantics()) - return op->emitError() << "op does not have tensor semantics"; - - Location loc = op.getLoc(); - SmallVector newInputBuffers; - newInputBuffers.reserve(op.getNumInputs()); - for (OpOperand *opOperand : op.getInputOperands()) { - if (op.isScalar(opOperand)) { - newInputBuffers.push_back(opOperand->get()); - continue; - } - newInputBuffers.push_back(lookup(bvm, opOperand->get())); - assert(newInputBuffers.back() && "missing buffer"); - } - SmallVector newOutputBuffers; - // Try to allocate new buffers depending on op's inplace semantics. - if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm, - aliasInfo, allocationFns))) - return failure(); - - // Clone the newly bufferized op. - SmallVector newOperands = newInputBuffers; - newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); - - // Set insertion point now that potential alloc/dealloc are introduced. - b.setInsertionPoint(op); - op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands); - - // Replace the results of the old op with the new output buffers. 
- if (op->getNumResults()) - map(bvm, op->getResults(), newOutputBuffers); - - // The original op will be DCE'd away later. - - return success(); -} - /// In a first approximation, all the function arguments of a FuncOp are marked /// inplaceable. For now, it is the responsibility of the `callOp` bufferization /// to allow FuncOp that are inplaceable to write inPlace. @@ -1725,142 +1429,6 @@ return success(); } -/// tensor::CastOp bufferizes to memref::CastOp. -static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(castOp); - - Value resultBuffer = - getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn); - if (!resultBuffer) - return failure(); - Type sourceType = resultBuffer.getType(); - auto rankedMemRefType = sourceType.dyn_cast(); - auto unrankedMemRefType = sourceType.dyn_cast(); - assert(rankedMemRefType || unrankedMemRefType); - Attribute memorySpace = rankedMemRefType - ? rankedMemRefType.getMemorySpace() - : unrankedMemRefType.getMemorySpace(); - TensorType tensorType = castOp.getResult().getType().cast(); - MemRefLayoutAttrInterface layout = - rankedMemRefType && tensorType.isa() - ? 
rankedMemRefType.getLayout() - : MemRefLayoutAttrInterface(); - Type memRefType = getContiguousOrUnrankedMemRefType( - castOp.getResult().getType(), layout, memorySpace); - Value res = - b.create(castOp.getLoc(), memRefType, resultBuffer); - aliasInfo.insertNewBufferEquivalence(res, castOp.getResult()); - map(bvm, castOp.getResult(), res); - return success(); -} - -static LogicalResult bufferize(OpBuilder &b, arith::ConstantOp constantOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo) { - assert(constantOp.getType().dyn_cast() && - "not a constant ranked tensor"); - auto moduleOp = constantOp->getParentOfType(); - if (!moduleOp) { - return constantOp.emitError( - "cannot bufferize constants not within builtin.module op"); - } - GlobalCreator globalCreator(moduleOp); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(constantOp); - - auto globalMemref = globalCreator.getGlobalFor(constantOp); - Value memref = b.create( - constantOp.getLoc(), globalMemref.type(), globalMemref.getName()); - aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult()); - map(bvm, constantOp, memref); - - return success(); -} - -/// DimOp tensor operand is modified inplace. This allows leaving dead -/// tensors behind that will get DCE'd. -static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(dimOp); - - if (dimOp.source().getType().isa()) { - Value v = lookup(bvm, dimOp.source()); - assert(v && "missing buffer"); - dimOp.result().replaceAllUsesWith( - b.create(dimOp.getLoc(), v, dimOp.index())); - } - return success(); -} - -static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) { - // Take a guard before anything else. 
- OpBuilder::InsertionGuard g(b); - - for (OpResult opResult : forOp->getResults()) { - if (!opResult.getType().isa()) - continue; - // TODO: Atm we bail on unranked TensorType because we don't know how to - // alloc an UnrankedMemRefType + its underlying ranked MemRefType. - assert(opResult.getType().isa() && - "unsupported unranked tensor"); - - // TODO: More general: Matching bbArg does not bufferize to a read. - Value resultBuffer = - getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn); - if (!resultBuffer) - return failure(); - - OpOperand &opOperand = forOp.getOpOperandForResult(opResult); - BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand); - aliasInfo.createAliasInfoEntry(resultBuffer); - aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer); - map(bvm, bbArg, resultBuffer); - map(bvm, opResult, resultBuffer); - } - - return success(); -} - -static LogicalResult bufferize(OpBuilder &b, scf::IfOp ifOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - - for (OpResult opResult : ifOp->getResults()) { - if (!opResult.getType().isa()) - continue; - // TODO: Atm we bail on unranked TensorType because we don't know how to - // alloc an UnrankedMemRefType + its underlying ranked MemRefType. - assert(opResult.getType().isa() && - "unsupported unranked tensor"); - - Value resultBuffer = - getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn); - if (!resultBuffer) - return failure(); - - aliasInfo.createAliasInfoEntry(resultBuffer); - map(bvm, opResult, resultBuffer); - } - - return success(); -} - /// FuncOp always creates TensorToMemRef ops. static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp, BlockAndValueMapping &bvm, @@ -1887,444 +1455,40 @@ return success(); } -/// InitTensor always allocates (unless it was eliminated). 
-/// TODO: consider hoisting across function boundaries prior to bufferization. -static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) { - // The InitTensorOp may have been eliminated. - if (initTensorOp->getUses().empty()) - return success(); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(initTensorOp); - - Value alloc = createNewAllocDeallocPairForShapedValue( - b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo, - allocationFn); - map(bvm, initTensorOp.result(), alloc); - return success(); -} - -/// ReturnOp always creates memref::TensorLoadOp. -static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - // Cannot insert after returnOp. - b.setInsertionPoint(returnOp); - - assert(isa(returnOp->getParentOp()) && - "only support FuncOp parent for ReturnOp"); - for (OpOperand &operand : returnOp->getOpOperands()) { - auto tensorType = operand.get().getType().dyn_cast(); - if (!tensorType) - continue; - Value v = lookup(bvm, operand.get()); - assert(v && "missing buffer for result"); - Value returnTensor = b.create(returnOp.getLoc(), v); - operand.set(returnTensor); - aliasInfo.insertNewBufferEquivalence(returnTensor, v); - map(bvm, returnTensor, v); - } - return success(); -} - -/// Bufferization for TiledLoopOp.. -static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - - // Allocate output buffers if needed, forward output tensor args to the - // terminator. 
- Operation *yieldOp = tiledLoopOp.getBody()->getTerminator(); - Block *body = tiledLoopOp.getBody(); - - // Take copies of the old input and output operands, so we can insert inplace - // easily. - auto oldInputs = llvm::to_vector<4>(tiledLoopOp.inputs()); - auto oldOutputs = llvm::to_vector<4>(tiledLoopOp.outputs()); - - int numLoops = tiledLoopOp.getNumLoops(); - int numControlOperands = tiledLoopOp.getNumControlOperands(); - - // Add buffers for outputs and the corresponding block arguments. - // Keep separate iterators to increment without further leaking impl. details. - // Start with outputs to avoid interference from new input buffers. - int numNewOutputBuffers = 0; - int resultIndex = 0; - int oldOutputBBArgIndex = numLoops + oldInputs.size(); - int nextOutputBBArgIndex = numLoops + oldInputs.size() + oldOutputs.size(); - int nextOutputOperandIndex = - numControlOperands + oldInputs.size() + oldOutputs.size(); - for (Value oldOutputTensor : oldOutputs) { - if (!oldOutputTensor.getType().isa()) { - // Skip and increment the old bbarg index only. - ++oldOutputBBArgIndex; - // Do not increment resultIndex as only tensors are returned. - // TODO: better interface to avoid leaking such impl details. - continue; - } - - assert(oldOutputTensor.getType().isa() && - "bufferizable output must be a ranked tensor"); - - const OpResult &opResult = tiledLoopOp->getResult(resultIndex); - OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex); - Value resultBuffer = - getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn); - if (!resultBuffer) - return failure(); - - // Insert mapping and aliasing info. - aliasInfo.createAliasInfoEntry(resultBuffer); - aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer); - map(bvm, opResult, resultBuffer); - - // Insert new operand and bbArg. 
- tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer); - BlockArgument newBufferBBArg = - body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType()); - BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex); - // Insert mapping and aliasing info. - aliasInfo.createAliasInfoEntry(newBufferBBArg); - aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg); - map(bvm, oldTensorBBArg, newBufferBBArg); - - // Set operand of `linalg.yield` to the bbArg so it just canonicalizes away - // later. - yieldOperand.set(oldTensorBBArg); - - // Increment indices. - ++numNewOutputBuffers; - ++resultIndex; - ++oldOutputBBArgIndex; - ++nextOutputBBArgIndex; - ++nextOutputOperandIndex; - } - - // Add buffers for inputs and the corresponding block arguments. - // Keep separate iterators to increment without further leaking impl. details. - int numNewInputBuffers = 0; - int oldInputBBArgIndex = numLoops; - int nextInputBBArgIndex = numLoops + oldInputs.size(); - int nextInputOperandIndex = numControlOperands + oldInputs.size(); - for (Value oldInputTensor : oldInputs) { - if (!oldInputTensor.getType().isa()) { - // Skip and increment the old bbarg index only. - ++oldInputBBArgIndex; - continue; - } +//===----------------------------------------------------------------------===// +// Bufferization analyses. +//===----------------------------------------------------------------------===// - Value inputBuffer = lookup(bvm, oldInputTensor); - assert(inputBuffer && " missing buffer for operand"); +/// Determine if `operand` can be bufferized in-place with `result`. If so, set +/// InPlaceSpec::True on the result. Otherwise, set InPlaceSpec::False on the +/// result. 
+static LogicalResult +bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result, + BufferizationAliasInfo &aliasInfo, + const DominanceInfo &domInfo) { +#ifndef NDEBUG + SmallVector opOperands = getAliasingOpOperand(result); + assert(llvm::find(opOperands, &operand) != opOperands.end() && + "operand and result do not match"); +#endif // NDEBUG - // Insert new operand and bbArg. - tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer); - BlockArgument newBufferBBArg = - body->insertArgument(nextInputBBArgIndex, inputBuffer.getType()); - BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex); + int64_t resultNumber = result.getResultNumber(); + (void)resultNumber; + LDBG('\n'); + LDBG("Inplace analysis for <- #" << resultNumber << " -> #" + << operand.getOperandNumber() << " in " + << printValueInfo(result) << '\n'); - // Insert mapping and aliasing info. - aliasInfo.createAliasInfoEntry(newBufferBBArg); - aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg); - map(bvm, oldTensorBBArg, newBufferBBArg); + bool foundInterference = + aliasInfo.wouldCreateWriteToNonWritableBuffer(operand, result) || + aliasInfo.wouldCreateReadAfterWriteInterference(operand, result, domInfo); - // Increment indices. - ++numNewInputBuffers; - ++oldInputBBArgIndex; - ++nextInputBBArgIndex; - ++nextInputOperandIndex; - } + if (foundInterference) + aliasInfo.bufferizeOutOfPlace(result); + else + aliasInfo.bufferizeInPlace(result, operand); - // Update segment sizes. - // TODO: Helper method to avoid leaking impl details. - tiledLoopOp->setAttr( - TiledLoopOp::getOperandSegmentSizeAttr(), - b.getI32VectorAttr( - {numLoops, numLoops, numLoops, - static_cast(oldInputs.size()) + numNewInputBuffers, - static_cast(oldOutputs.size()) + numNewOutputBuffers})); - - return success(); -} - -/// Bufferize ExtractSliceOp to subview with optional alloc + copy depending on -/// whether or not it is marked inplaceable. 
-/// Note that `getInplaceableOpResult` on a ExtractSliceOp always returns null. -/// As consequence a ExtractSliceOp always alloc + copy when taken in -/// isolation. -static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - - LDBG("bufferize: " << *extractSliceOp << '\n'); - - Location loc = extractSliceOp.getLoc(); - // Bail if source was not bufferized. - Value srcMemref = lookup(bvm, extractSliceOp.source()); - if (!srcMemref) - return failure(); - auto srcMemrefType = srcMemref.getType().cast(); - auto dstTensorType = - extractSliceOp.result().getType().cast(); - - // If not inplaceable, alloc. - Value alloc; - auto inPlace = getInPlace(extractSliceOp->getResult(0)); - if (inPlace != InPlaceSpec::True) - alloc = createNewAllocDeallocPairForShapedValue( - b, loc, extractSliceOp.result(), aliasInfo, allocationFn); - - // Set insertion point now that potential alloc/dealloc are introduced. - b.setInsertionPoint(extractSliceOp); - - // Bufferize to subview. - auto subviewMemRefType = - memref::SubViewOp::inferRankReducedResultType( - dstTensorType.getRank(), srcMemrefType, - extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(), - extractSliceOp.getMixedStrides()) - .cast(); - Value subView = b.create( - loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(), - extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides()); - // Insert new alias. - aliasInfo.insertNewBufferAlias(subView, srcMemref); - - /// If not inplaceable, copy. - if (alloc) { - // Do not copy if the copied data is never read. 
- if (isValueRead(extractSliceOp.result())) - b.create(extractSliceOp.getLoc(), subView, alloc); - subView = alloc; - } - - map(bvm, extractSliceOp.result(), subView); - return success(); -} - -static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(insertSliceOp); - - LDBG("bufferize: " << *insertSliceOp << '\n'); - - Location loc = insertSliceOp.getLoc(); - // Since insert_slice arise from tiling and introducing loops, this - // case is generally a deal breaker. When used with loops, this ends up - // cloning the whole tensor on every single iteration and is a symptom - // of a catastrophically bad scheduling decision. - // TODO: be very loud about it or even consider failing the pass. - // Alloc a copy for `insertSliceOp.dest()`, it will become the result - // buffer. - Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm, - aliasInfo, allocationFn); - if (!dstMemref) - return failure(); - auto dstMemrefType = dstMemref.getType().cast(); - - Value srcMemref = lookup(bvm, insertSliceOp.source()); - if (!srcMemref) - return failure(); - auto subviewMemRefType = - memref::SubViewOp::inferRankReducedResultType( - insertSliceOp.getSourceType().getRank(), dstMemrefType, - insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), - insertSliceOp.getMixedStrides()) - .cast(); - - // A copy of the source buffer is needed if either: - // - The producer of `source` is not inplace. This is the case where a - // slice is computed out of place into the inplace full tensor. - // - The result is not inplace. This is the case where the whole tensor is - // cloned and the clone needs to be updated. - auto inPlace = getInPlace(insertSliceOp->getResult(0)); - // TODO: Is this necessary? 
- if (!aliasInfo.isSourceEquivalentToAMatchingInplaceExtractSliceOp( - insertSliceOp) || - inPlace != InPlaceSpec::True) { - LDBG("insert_slice needs extra source copy: " << insertSliceOp.source() - << " -> copy\n"); - // Take a subview of the dst. - Value subView = b.create( - loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(), - insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); - // Insert new alias. - aliasInfo.insertNewBufferAlias(subView, dstMemref); - b.create(insertSliceOp.getLoc(), srcMemref, subView); - } - - map(bvm, insertSliceOp.result(), dstMemref); - - return success(); -} - -static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFn) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - - if (op.getShapedType().isa()) - return failure(); - - /// transfer_read from buffer always reads from the bufferized - /// op.source(). - if (auto readOp = dyn_cast(op.getOperation())) { - Value v = lookup(bvm, op.source()); - assert(v && "missing buffer"); - readOp.sourceMutable().assign(v); - return success(); - } - - // Create a new transfer_write on buffer that doesn't have a return value. - // Leave the previous transfer_write to dead code as it still has uses at - // this point. - auto writeOp = cast(op.getOperation()); - Value resultBuffer = - getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn); - if (!resultBuffer) - return failure(); - b.create( - op.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(), - writeOp.permutation_map(), - writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr()); - map(bvm, op->getResult(0), resultBuffer); - - return success(); -} - -static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo) { - // Take a guard before anything else. 
- OpBuilder::InsertionGuard g(b); - // Cannot create IR past a yieldOp. - b.setInsertionPoint(yieldOp); - - if (auto execOp = dyn_cast(yieldOp->getParentOp())) { - if (execOp->getNumResults() != 0) - return execOp->emitError( - "expected result-less scf.execute_region containing op"); - return success(); - } - - if (auto ifOp = dyn_cast(yieldOp->getParentOp())) - return success(); - - scf::ForOp forOp = dyn_cast(yieldOp->getParentOp()); - if (!forOp) - return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp"); - for (OpOperand &operand : yieldOp->getOpOperands()) { - auto tensorType = operand.get().getType().dyn_cast(); - if (!tensorType) - continue; - - OpOperand &forOperand = forOp.getOpOperandForResult( - forOp->getResult(operand.getOperandNumber())); - auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - Value yieldedBuffer = lookup(bvm, operand.get()); - Value bbArgBuffer = lookup(bvm, bbArg); - if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) { - // TODO: this could get resolved with copies but it can also turn into - // swaps so we need to be careful about order of copies. - return yieldOp->emitError() - << "Yield operand #" << operand.getOperandNumber() - << " does not bufferize to an equivalent buffer to the matching" - << " enclosing scf::for operand"; - } - - // Buffers are equivalent so the work is already done and we just yield the - // bbArg so that it later canonicalizes away. - operand.set(bbArg); - } - return success(); -} - -/// Bufferization for linalg::YieldOp either does not involve tensors or just -/// results in later canonicalization. In either case it does nothing. -static LogicalResult bufferize(OpBuilder &b, linalg::YieldOp yieldOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - // Cannot create IR past a yieldOp. - b.setInsertionPoint(yieldOp); - - // No tensors -> success. 
- if (!llvm::any_of(yieldOp.getOperandTypes(), isaTensor)) - return success(); - // linalg::YieldOp nested under TiledLoop must just canonicalize. - if (yieldOp->getParentOfType()) - return success(); - llvm_unreachable("unexpected yieldOp"); -} - -/// Bufferization for tensor::ExtractOp just translate to memref.load, it only -/// reads the tensor. -static LogicalResult bufferize(OpBuilder &b, tensor::ExtractOp extractOp, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(extractOp); - - Location loc = extractOp.getLoc(); - Value srcMemref = lookup(bvm, extractOp.tensor()); - Value l = b.create(loc, srcMemref, extractOp.indices()); - extractOp.replaceAllUsesWith(l); - return success(); -} -//===----------------------------------------------------------------------===// -// Bufferization analyses. -//===----------------------------------------------------------------------===// - -/// Determine if `operand` can be bufferized in-place with `result`. If so, set -/// InPlaceSpec::True on the result. Otherwise, set InPlaceSpec::False on the -/// result. 
-static LogicalResult -bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result, - BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo) { -#ifndef NDEBUG - SmallVector opOperands = getAliasingOpOperand(result); - assert(llvm::find(opOperands, &operand) != opOperands.end() && - "operand and result do not match"); -#endif // NDEBUG - - int64_t resultNumber = result.getResultNumber(); - (void)resultNumber; - LDBG('\n'); - LDBG("Inplace analysis for <- #" << resultNumber << " -> #" - << operand.getOperandNumber() << " in " - << printValueInfo(result) << '\n'); - - bool foundInterference = - aliasInfo.wouldCreateWriteToNonWritableBuffer(operand, result) || - aliasInfo.wouldCreateReadAfterWriteInterference(operand, result, domInfo); - - if (foundInterference) - aliasInfo.bufferizeOutOfPlace(result); - else - aliasInfo.bufferizeInPlace(result, operand); - - LDBG("Done inplace analysis for result #" << resultNumber << '\n'); + LDBG("Done inplace analysis for result #" << resultNumber << '\n'); return success(); } @@ -2348,7 +1512,9 @@ bufferizableInPlaceAnalysisAliasOnlyOp(OpOperand &operand, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { - OpResult result = getAliasingOpResult(operand); + auto bufferizableOp = dyn_cast(operand.getOwner()); + assert(bufferizableOp && "expected op with known bufferization behavior"); + OpResult result = bufferizableOp.getAliasingOpResult(operand); assert(result && "expected that the OpOperand has an aliasing OpResult"); return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo); } @@ -2360,10 +1526,12 @@ bufferizableInPlaceAnalysis(OpOperand &operand, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { - OpResult result = getInplaceableOpResult(operand); - if (!result) + auto bufferizableOp = dyn_cast(operand.getOwner()); + if (!bufferizableOp) return success(); - return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo); + if (OpResult result = 
bufferizableOp.getInplaceableOpResult(operand)) + return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo); + return success(); } /// Analyze the `ops` to determine which OpResults are inplaceable. Walk ops in @@ -2383,17 +1551,18 @@ // Walk ops in reverse for better interference analysis. for (Operation *op : reverse(ops)) { - for (OpOperand &opOperand : op->getOpOperands()) { - if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) - return failure(); - - // Special logic to analyze OpOperands that are not inplaceable on an - // OpResult but may create an alias. - if (bufferizesToAliasOnly(opOperand)) - if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(opOperand, aliasInfo, - domInfo))) + for (OpOperand &opOperand : op->getOpOperands()) + if (opOperand.get().getType().isa()) { + if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) return failure(); - } + + // Special logic to analyze OpOperands that are not inplaceable on an + // OpResult but may create an alias. + if (bufferizesToAliasOnly(opOperand)) + if (failed(bufferizableInPlaceAnalysisAliasOnlyOp( + opOperand, aliasInfo, domInfo))) + return failure(); + } } return success(); @@ -2529,42 +1698,31 @@ AllocationCallbacks allocationFns, DenseMap *bufferizedFunctionTypes) { OpBuilder b(op->getContext()); - return TypeSwitch(op) - // Skip BufferCast and TensorLoad ops. 
- .Case( - [&](auto) { return success(); }) - .Case( - [&](auto op) { - LDBG("Begin bufferize:\n" << op << '\n'); - return bufferize(b, op, bvm, aliasInfo, allocationFns); - }) - .Case([&](auto op) { - LDBG("Begin bufferize:\n" << op << '\n'); - return bufferize(b, op, bvm, aliasInfo); - }) - .Case([&](CallOpInterface op) { - LDBG("Begin bufferize:\n" << op << '\n'); - if (!bufferizedFunctionTypes) - llvm_unreachable( - "null bufferizedFunctionTypes when bufferizing CallOpInterface"); - return bufferize(b, op, bvm, aliasInfo, allocationFns, - *bufferizedFunctionTypes); - }) - .Case([&](arith::ConstantOp op) { - if (!isaTensor(op.getResult().getType())) - return success(); - LDBG("Begin bufferize:\n" << op << '\n'); - return bufferize(b, op, bvm, aliasInfo); - }) - .Default([&](Operation *op) -> LogicalResult { - auto isaTensor = [](Type t) { return t.isa(); }; - if (any_of(op->getOperandTypes(), isaTensor) || - any_of(op->getResultTypes(), isaTensor)) - return op->emitError() << "unsupported op with tensors"; - return success(); - }); + + // CallOps are handled separately. + if (auto callOp = dyn_cast(op)) { + LDBG("Begin bufferize:\n" << callOp << '\n'); + if (!bufferizedFunctionTypes) + llvm_unreachable( + "null bufferizedFunctionTypes when bufferizing CallOpInterface"); + return bufferize(b, callOp, bvm, aliasInfo, allocationFns, + *bufferizedFunctionTypes); + } + + // Skip BufferCast and TensorLoad ops. + if (isa(op)) + return success(); + + // Bufferize using `BufferizableOpInterface`. + if (auto bufferizableOp = dyn_cast(op)) + return bufferizableOp.bufferize(b, bvm, aliasInfo, allocationFns); + + // Other op with tensors. No bufferization method specified. 
+ auto isaTensor = [](Type t) { return t.isa(); }; + if (any_of(op->getOperandTypes(), isaTensor) || + any_of(op->getResultTypes(), isaTensor)) + return op->emitError() << "unsupported op with tensors"; + return success(); } static LogicalResult bufferizeFuncOpInternals( @@ -2883,7 +2041,11 @@ void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); + registerBufferiableOpInterfaceExternalModels(registry); } private: @@ -3189,3 +2351,1169 @@ std::unique_ptr mlir::createLinalgComprehensiveModuleBufferizePass() { return std::make_unique(); } + +//===----------------------------------------------------------------------===// +// BufferizableOpInterface Implementations +//===----------------------------------------------------------------------===// + +// TODO: Move these to a different file and BUILD target, so that they are +// decoupled from ComprehensiveBufferize. + +namespace mlir { +namespace linalg { +namespace arith_ext { + +struct ConstantOpInterface + : public BufferizableOpInterface::ExternalModel { + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + return {}; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, + AllocationCallbacks &allocationFn) const { + auto constantOp = cast(op); + if (!isaTensor(constantOp.getResult().getType())) + return success(); + assert(constantOp.getType().dyn_cast() && + "not a constant ranked tensor"); + auto moduleOp = constantOp->getParentOfType(); + if (!moduleOp) { + return constantOp.emitError( + "cannot bufferize constants not within builtin.module op"); + } + GlobalCreator globalCreator(moduleOp); + + // Take a guard before anything else. 
+ OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(constantOp); + + auto globalMemref = globalCreator.getGlobalFor(constantOp); + Value memref = b.create( + constantOp.getLoc(), globalMemref.type(), globalMemref.getName()); + aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult()); + map(bvm, constantOp, memref); + + return success(); + } +}; + +} // namespace arith_ext + +// TODO: Ops in the linalg dialect can directly implement this interface. +namespace linalg_ext { + +/// Helper function for LinalgOp bufferization. +/// When allocating a new buffer, analyze whether `op` wants to read form that +/// buffer. Only in that case, a copy of the result buffer may be needed. +static LogicalResult allocateBuffersForResults( + OpBuilder &b, Location loc, LinalgOp op, + SmallVectorImpl &resultBuffers, BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) { + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(op); + + // TODO: provide the proper interface to iterate on OpResults and get the + // matching OpOperands. + for (OpOperand *opOperand : op.getOutputOperands()) { + OpResult opResult = cast(op.getOperation()) + .getInplaceableOpResult(*opOperand); + assert(opResult && "could not find correspond OpResult"); + bool skipCopy = !op.payloadUsesValueFromOperand(opOperand); + Value resultBuffer = + getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns, skipCopy); + if (!resultBuffer) + return failure(); + resultBuffers.push_back(resultBuffer); + } + + if (op->getNumResults()) + map(bvm, op->getResults(), resultBuffers); + + return success(); +} + +/// Generic conversion for any LinalgOp on tensors. +static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, + AllocationCallbacks &allocationFns) { + // Take a guard before anything else. 
+ OpBuilder::InsertionGuard g(b); + + // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need + // basis. + if (!op.hasTensorSemantics()) + return op->emitError() << "op does not have tensor semantics"; + + Location loc = op.getLoc(); + SmallVector newInputBuffers; + newInputBuffers.reserve(op.getNumInputs()); + for (OpOperand *opOperand : op.getInputOperands()) { + if (op.isScalar(opOperand)) { + newInputBuffers.push_back(opOperand->get()); + continue; + } + newInputBuffers.push_back(lookup(bvm, opOperand->get())); + assert(newInputBuffers.back() && "missing buffer"); + } + SmallVector newOutputBuffers; + // Try to allocate new buffers depending on op's inplace semantics. + if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm, + aliasInfo, allocationFns))) + return failure(); + + // Clone the newly bufferized op. + SmallVector newOperands = newInputBuffers; + newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); + + // Set insertion point now that potential alloc/dealloc are introduced. + b.setInsertionPoint(op); + op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands); + + // Replace the results of the old op with the new output buffers. + if (op->getNumResults()) + map(bvm, op->getResults(), newOutputBuffers); + + // The original op will be DCE'd away later. 
+ + return success(); +} + +template +struct LinalgOpInterface + : public BufferizableOpInterface::ExternalModel, + OpTy> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + auto genericOp = cast(op); + return genericOp.isInputTensor(&opOperand) || + genericOp.isInitTensor(&opOperand); + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + auto genericOp = cast(op); + return genericOp.isOutputTensor(&opOperand); + } + + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + auto genericOp = cast(op); + return {genericOp.getOutputTensorOperands()[opResult.getResultNumber()]}; + } + + OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + auto genericOp = cast(op); + if (!opOperand.get().getType().isa()) + return OpResult(); + // For now assume inputs are never inplaceable. + // TODO: refine this. + if (opOperand.getOperandNumber() < genericOp.getNumInputs()) + return OpResult(); + int64_t outputOperandIndex = + opOperand.getOperandNumber() - genericOp.getNumInputs(); + int64_t numOutputBuffers = 0; + for (unsigned idx = 0; idx < outputOperandIndex; ++idx) + if (!genericOp.getOutputOperand(idx)->get().getType().isa()) + ++numOutputBuffers; + return genericOp->getResult(outputOperandIndex - numOutputBuffers); + } + + BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + return BufferRelation::Equivalent; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, + AllocationCallbacks &allocationFn) const { + return bufferizeLinalgOp(b, cast(op), bvm, aliasInfo, + allocationFn); + } +}; + +struct InitTensorOpInterface + : public BufferizableOpInterface::ExternalModel { + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + return {}; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BlockAndValueMapping &bvm, + BufferizationAliasInfo 
&aliasInfo, + AllocationCallbacks &allocationFn) const { + auto initTensorOp = cast(op); + + // The InitTensorOp may have been eliminated. + if (initTensorOp->getUses().empty()) + return success(); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(initTensorOp); + + Value alloc = createNewAllocDeallocPairForShapedValue( + b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo, + allocationFn); + map(bvm, initTensorOp.result(), alloc); + return success(); + } +}; + +struct TiledLoopOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + // TiledLoop alone doesn't bufferize to a memory read, one of the uses of + // its matching bbArg may. + auto tiledLoopOp = cast(op); + return isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand)); + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + // TiledLoop alone doesn't bufferize to a memory write, one of the uses of + // its matching bbArg may. + auto bufferizableOp = cast(op); + return static_cast(bufferizableOp.getAliasingOpResult(opOperand)); + } + + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + // TODO: TiledLoopOp helper method to avoid leaking impl details. 
+ auto tiledLoopOp = cast(op); + return {&op->getOpOperand(tiledLoopOp.getNumControlOperands() + + tiledLoopOp.getNumInputs() + + opResult.getResultNumber())}; + } + + OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + auto tiledLoopOp = cast(op); + return tiledLoopOp.getTiedOpResult(opOperand); + } + + BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + return BufferRelation::Equivalent; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, + AllocationCallbacks &allocationFn) const { + auto tiledLoopOp = cast(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + + // Allocate output buffers if needed, forward output tensor args to the + // terminator. + Operation *yieldOp = tiledLoopOp.getBody()->getTerminator(); + Block *body = tiledLoopOp.getBody(); + + // Take copies of the old input and output operands, so we can insert + // inplace easily. + auto oldInputs = llvm::to_vector<4>(tiledLoopOp.inputs()); + auto oldOutputs = llvm::to_vector<4>(tiledLoopOp.outputs()); + + int numLoops = tiledLoopOp.getNumLoops(); + int numControlOperands = tiledLoopOp.getNumControlOperands(); + + // Add buffers for outputs and the corresponding block arguments. + // Keep separate iterators to increment without further leaking impl. + // details. Start with outputs to avoid interference from new input buffers. + int numNewOutputBuffers = 0; + int resultIndex = 0; + int oldOutputBBArgIndex = numLoops + oldInputs.size(); + int nextOutputBBArgIndex = numLoops + oldInputs.size() + oldOutputs.size(); + int nextOutputOperandIndex = + numControlOperands + oldInputs.size() + oldOutputs.size(); + for (Value oldOutputTensor : oldOutputs) { + if (!oldOutputTensor.getType().isa()) { + // Skip and increment the old bbarg index only. + ++oldOutputBBArgIndex; + // Do not increment resultIndex as only tensors are returned. 
+ // TODO: better interface to avoid leaking such impl details. + continue; + } + + assert(oldOutputTensor.getType().isa() && + "bufferizable output must be a ranked tensor"); + + const OpResult &opResult = tiledLoopOp->getResult(resultIndex); + OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex); + Value resultBuffer = + getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn); + if (!resultBuffer) + return failure(); + + // Insert mapping and aliasing info. + aliasInfo.createAliasInfoEntry(resultBuffer); + aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer); + map(bvm, opResult, resultBuffer); + + // Insert new operand and bbArg. + tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer); + BlockArgument newBufferBBArg = + body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType()); + BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex); + // Insert mapping and aliasing info. + aliasInfo.createAliasInfoEntry(newBufferBBArg); + aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg); + map(bvm, oldTensorBBArg, newBufferBBArg); + + // Set operand of `linalg.yield` to the bbArg so it just canonicalizes + // away later. + yieldOperand.set(oldTensorBBArg); + + // Increment indices. + ++numNewOutputBuffers; + ++resultIndex; + ++oldOutputBBArgIndex; + ++nextOutputBBArgIndex; + ++nextOutputOperandIndex; + } + + // Add buffers for inputs and the corresponding block arguments. + // Keep separate iterators to increment without further leaking impl. + // details. + int numNewInputBuffers = 0; + int oldInputBBArgIndex = numLoops; + int nextInputBBArgIndex = numLoops + oldInputs.size(); + int nextInputOperandIndex = numControlOperands + oldInputs.size(); + for (Value oldInputTensor : oldInputs) { + if (!oldInputTensor.getType().isa()) { + // Skip and increment the old bbarg index only. 
+ ++oldInputBBArgIndex; + continue; + } + + Value inputBuffer = lookup(bvm, oldInputTensor); + assert(inputBuffer && " missing buffer for operand"); + + // Insert new operand and bbArg. + tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer); + BlockArgument newBufferBBArg = + body->insertArgument(nextInputBBArgIndex, inputBuffer.getType()); + BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex); + + // Insert mapping and aliasing info. + aliasInfo.createAliasInfoEntry(newBufferBBArg); + aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg); + map(bvm, oldTensorBBArg, newBufferBBArg); + + // Increment indices. + ++numNewInputBuffers; + ++oldInputBBArgIndex; + ++nextInputBBArgIndex; + ++nextInputOperandIndex; + } + + // Update segment sizes. + // TODO: Helper method to avoid leaking impl details. + tiledLoopOp->setAttr( + TiledLoopOp::getOperandSegmentSizeAttr(), + b.getI32VectorAttr( + {numLoops, numLoops, numLoops, + static_cast(oldInputs.size()) + numNewInputBuffers, + static_cast(oldOutputs.size()) + numNewOutputBuffers})); + + return success(); + } +}; + +struct YieldOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + return false; + } + + OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + return OpResult(); + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, + AllocationCallbacks &allocationFn) const { + auto yieldOp = cast(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + // Cannot create IR past a yieldOp. + b.setInsertionPoint(yieldOp); + + // No tensors -> success. 
+ if (!llvm::any_of(yieldOp.getOperandTypes(), isaTensor)) + return success(); + // linalg::YieldOp nested under TiledLoop must just canonicalize. + if (yieldOp->getParentOfType()) + return success(); + llvm_unreachable("unexpected yieldOp"); + } +}; + +} // namespace linalg_ext + +namespace scf_ext { + +struct IfOpInterface + : public BufferizableOpInterface::ExternalModel { + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + auto ifOp = cast(op); + // Either one of the corresponding yield values from the then/else branches + // may alias with the result. + size_t resultNum = std::distance(op->getOpResults().begin(), + llvm::find(op->getOpResults(), opResult)); + return {&ifOp.thenYield()->getOpOperand(resultNum), + &ifOp.elseYield()->getOpOperand(resultNum)}; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, + AllocationCallbacks &allocationFn) const { + auto ifOp = cast(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + + for (OpResult opResult : ifOp->getResults()) { + if (!opResult.getType().isa()) + continue; + // TODO: Atm we bail on unranked TensorType because we don't know how to + // alloc an UnrankedMemRefType + its underlying ranked MemRefType. + assert(opResult.getType().isa() && + "unsupported unranked tensor"); + + Value resultBuffer = + getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn); + if (!resultBuffer) + return failure(); + + aliasInfo.createAliasInfoEntry(resultBuffer); + map(bvm, opResult, resultBuffer); + } + + return success(); + } +}; + +struct ForOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of + // its matching bbArg may. 
+ auto forOp = cast(op); + return isValueRead(forOp.getRegionIterArgForOpOperand(opOperand)); + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + // Tensor iter_args of scf::ForOps are always considered as a write. This is + // to simplify the analysis. + // TODO: Consider doing sth. like isValueWritten. + return true; + } + + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + auto forOp = cast(op); + return {&forOp.getIterOpOperands()[opResult.getResultNumber()]}; + } + + OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + auto forOp = cast(op); + if (!opOperand.get().getType().isa()) + return OpResult(); + return forOp.getResultForOpOperand(opOperand); + } + + BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + return BufferRelation::Equivalent; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, + AllocationCallbacks &allocationFn) const { + auto forOp = cast(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + + for (OpResult opResult : forOp->getResults()) { + if (!opResult.getType().isa()) + continue; + // TODO: Atm we bail on unranked TensorType because we don't know how to + // alloc an UnrankedMemRefType + its underlying ranked MemRefType. + assert(opResult.getType().isa() && + "unsupported unranked tensor"); + + // TODO: More general: Matching bbArg does not bufferize to a read. 
+      Value resultBuffer =
+          getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+      if (!resultBuffer)
+        return failure();
+
+      OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
+      BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
+      aliasInfo.createAliasInfoEntry(resultBuffer);
+      aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
+      map(bvm, bbArg, resultBuffer);
+      map(bvm, opResult, resultBuffer);
+    }
+
+    return success();
+  }
+};
+
+/// Bufferization model for `scf::YieldOp`.
+///
+/// A yield reads its operands, never writes them, and none of its operands
+/// bufferizes in place with a result of the yield itself.
+struct YieldOpInterface
+    : public BufferizableOpInterface::ExternalModel {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  /// Bufferizes the yield depending on its parent op:
+  ///  - `scf.execute_region` parent: it must be result-less; nothing to do.
+  ///  - parent matched as `ifOp`: nothing to do.
+  ///  - `scf::ForOp` parent: each yielded tensor must already bufferize to a
+  ///    buffer equivalent to the buffer of the matching loop iter_arg; the
+  ///    yield operand is then replaced by that iter_arg.
+  /// Any other parent, or a non-equivalent yielded buffer, emits an error.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo,
+                          AllocationCallbacks &allocationFn) const {
+    auto yieldOp = cast(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    // Cannot create IR past a yieldOp.
+    b.setInsertionPoint(yieldOp);
+
+    if (auto execOp = dyn_cast(yieldOp->getParentOp())) {
+      if (execOp->getNumResults() != 0)
+        return execOp->emitError(
+            "expected result-less scf.execute_region containing op");
+      return success();
+    }
+
+    if (auto ifOp = dyn_cast(yieldOp->getParentOp()))
+      return success();
+
+    scf::ForOp forOp = dyn_cast(yieldOp->getParentOp());
+    if (!forOp)
+      return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
+    for (OpOperand &operand : yieldOp->getOpOperands()) {
+      auto tensorType = operand.get().getType().dyn_cast();
+      if (!tensorType)
+        continue;
+
+      // Yield operand #i feeds loop result #i, hence the loop's i-th
+      // iter_arg.
+      OpOperand &forOperand = forOp.getOpOperandForResult(
+          forOp->getResult(operand.getOperandNumber()));
+      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+      Value yieldedBuffer = lookup(bvm, operand.get());
+      Value bbArgBuffer = lookup(bvm, bbArg);
+      if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
+                                                   bbArgBuffer)) {
+        // TODO: this could get resolved with copies but it can also turn into
+        // swaps so we need to be careful about order of copies.
+        return yieldOp->emitError()
+               << "Yield operand #" << operand.getOperandNumber()
+               << " does not bufferize to an equivalent buffer to the matching"
+               << " enclosing scf::for operand";
+      }
+
+      // Buffers are equivalent so the work is already done and we just yield
+      // the bbArg so that it later canonicalizes away.
+      operand.set(bbArg);
+    }
+    return success();
+  }
+};
+
+} // namespace scf_ext
+
+namespace std_ext {
+
+/// Bufferization model for call ops.
+///
+/// The analysis answers here are deliberately conservative: every operand
+/// may be read and written, and no operand is in place with a result. A
+/// precise answer depends on the callee. Calls are bufferized by a separate
+/// code path, so `bufferize` must never be reached.
+struct CallOpInterface
+    : public BufferizableOpInterface::ExternalModel {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
+    // of the matching bbArg may. It is the responsibility of the caller to
+    // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
+    // conservative.
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    // CallOpInterface alone doesn't bufferize to a memory write, one of the
+    // uses of the matching bbArg may. It is the responsibility of the caller to
+    // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
+    // conservative.
+    return true;
+  }
+
+  SmallVector getAliasingOpOperand(Operation *op,
+                                   OpResult opResult) const {
+    // TODO: Can we do better?
+    return {};
+  }
+
+  OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const {
+    // CallOpInterface is special, it needs to wait for the callee to be
+    // bufferized and needs to inspect the BufferAliasInfo object. It can't
+    // make a proper determination by itself and needs to be conservative.
+    return OpResult();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo,
+                          AllocationCallbacks &allocationFn) const {
+    llvm_unreachable("CallOps are handled separately");
+    return failure();
+  }
+};
+
+/// Bufferization model for the return op of a FuncOp.
+///
+/// Returned values are read, never written, and no operand is in place with
+/// a result.
+struct ReturnOpInterface
+    : public BufferizableOpInterface::ExternalModel {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  /// Rewrites each tensor operand (those passing the `tensorType` check) to a
+  /// new value `returnTensor` created from the operand's buffer. It records
+  /// `returnTensor` as equivalent to, and mapped to, that buffer. Other
+  /// operands are skipped.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo,
+                          AllocationCallbacks &allocationFn) const {
+    auto returnOp = cast(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    // Cannot insert after returnOp.
+    b.setInsertionPoint(returnOp);
+
+    assert(isa(returnOp->getParentOp()) &&
+           "only support FuncOp parent for ReturnOp");
+    for (OpOperand &operand : returnOp->getOpOperands()) {
+      auto tensorType = operand.get().getType().dyn_cast();
+      if (!tensorType)
+        continue;
+      Value v = lookup(bvm, operand.get());
+      assert(v && "missing buffer for result");
+      Value returnTensor = b.create(returnOp.getLoc(), v);
+      operand.set(returnTensor);
+      aliasInfo.insertNewBufferEquivalence(returnTensor, v);
+      map(bvm, returnTensor, v);
+    }
+    return success();
+  }
+};
+
+} // namespace std_ext
+
+namespace tensor_ext {
+
+/// Bufferization model for the tensor cast op.
+///
+/// The cast neither reads nor writes memory. Its operand #0 only aliases the
+/// result, and the two have an equivalent buffer relation.
+struct CastOpInterface
+    : public BufferizableOpInterface::ExternalModel {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  bool bufferizesToAliasOnly(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  SmallVector getAliasingOpOperand(Operation *op,
+                                   OpResult opResult) const {
+    return {&op->getOpOperand(0)};
+  }
+
+  OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return op->getResult(0);
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::Equivalent;
+  }
+
+  /// Casts the result buffer to a memref type derived from the cast's result
+  /// tensor type. It keeps the buffer's memory space. It keeps the buffer's
+  /// layout only when the buffer is ranked and `tensorType` passes the type
+  /// check below; otherwise it uses a default layout.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo,
+                          AllocationCallbacks &allocationFn) const {
+    auto castOp = cast(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(castOp);
+
+    Value resultBuffer =
+        getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn);
+    if (!resultBuffer)
+      return failure();
+    Type sourceType = resultBuffer.getType();
+    auto rankedMemRefType = sourceType.dyn_cast();
+    auto unrankedMemRefType = sourceType.dyn_cast();
+    assert(rankedMemRefType || unrankedMemRefType);
+    Attribute memorySpace = rankedMemRefType
+                                ? rankedMemRefType.getMemorySpace()
+                                : unrankedMemRefType.getMemorySpace();
+    TensorType tensorType = castOp.getResult().getType().cast();
+    MemRefLayoutAttrInterface layout =
+        rankedMemRefType && tensorType.isa()
+            ? rankedMemRefType.getLayout()
+            : MemRefLayoutAttrInterface();
+    Type memRefType = getContiguousOrUnrankedMemRefType(
+        castOp.getResult().getType(), layout, memorySpace);
+    Value res =
+        b.create(castOp.getLoc(), memRefType, resultBuffer);
+    aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
+    map(bvm, castOp.getResult(), res);
+    return success();
+  }
+};
+
+/// Bufferization model for the dim op.
+///
+/// It reads its source, writes nothing, and has no in-place result.
+struct DimOpInterface
+    : public BufferizableOpInterface::ExternalModel {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  /// When the source type passes the check below, replaces all uses of the
+  /// result with a new op built on the source's buffer and the same index.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo,
+                          AllocationCallbacks &allocationFn) const {
+    auto dimOp = cast(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(dimOp);
+
+    if (dimOp.source().getType().isa()) {
+      Value v = lookup(bvm, dimOp.source());
+      assert(v && "missing buffer");
+      dimOp.result().replaceAllUsesWith(
+          b.create(dimOp.getLoc(), v, dimOp.index()));
+    }
+    return success();
+  }
+};
+
+/// Bufferization model for `extract_slice`.
+///
+/// The op itself neither reads nor writes memory. Its source (operand #0)
+/// only aliases the result, with `BufferRelation::None`.
+struct ExtractSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  bool bufferizesToAliasOnly(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  SmallVector getAliasingOpOperand(Operation *op,
+                                   OpResult opResult) const {
+    return {&op->getOpOperand(0) /*source*/};
+  }
+
+  OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(0) /*source*/
+               ? op->getResult(0)
+               : OpResult();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::None;
+  }
+
+  /// Bufferizes to a subview of the source buffer.
+  ///
+  /// If the result is not marked `InPlaceSpec::True`, a new allocation is
+  /// created instead. The subview is copied into it only if the result is
+  /// read. Fails if the source has no buffer yet.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo,
+                          AllocationCallbacks &allocationFn) const {
+    auto extractSliceOp = cast(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+
+    LDBG("bufferize: " << *extractSliceOp << '\n');
+
+    Location loc = extractSliceOp.getLoc();
+    // Bail if source was not bufferized.
+    Value srcMemref = lookup(bvm, extractSliceOp.source());
+    if (!srcMemref)
+      return failure();
+    auto srcMemrefType = srcMemref.getType().cast();
+    auto dstTensorType =
+        extractSliceOp.result().getType().cast();
+
+    // If not inplaceable, alloc.
+    Value alloc;
+    auto inPlace = getInPlace(extractSliceOp->getResult(0));
+    if (inPlace != InPlaceSpec::True)
+      alloc = createNewAllocDeallocPairForShapedValue(
+          b, loc, extractSliceOp.result(), aliasInfo, allocationFn);
+
+    // Set insertion point now that potential alloc/dealloc are introduced.
+    b.setInsertionPoint(extractSliceOp);
+
+    // Bufferize to subview.
+    auto subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
+            dstTensorType.getRank(), srcMemrefType,
+            extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
+            extractSliceOp.getMixedStrides())
+            .cast();
+    Value subView = b.create(
+        loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
+        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+    // Insert new alias.
+    aliasInfo.insertNewBufferAlias(subView, srcMemref);
+
+    // If not inplaceable, copy.
+    if (alloc) {
+      // Do not copy if the copied data is never read.
+      if (isValueRead(extractSliceOp.result()))
+        b.create(extractSliceOp.getLoc(), subView, alloc);
+      subView = alloc;
+    }
+
+    map(bvm, extractSliceOp.result(), subView);
+    return success();
+  }
+};
+
+/// Bufferization model for `extract`.
+///
+/// It reads its tensor operand, writes nothing, and has no in-place result.
+struct ExtractOpInterface
+    : public BufferizableOpInterface::ExternalModel {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  /// Replaces all uses of the op with a new op built on the source tensor's
+  /// buffer at the same indices. Fails if the source has no buffer yet.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo,
+                          AllocationCallbacks &allocationFn) const {
+    auto extractOp = cast(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(extractOp);
+
+    Location loc = extractOp.getLoc();
+    // Bail if source was not bufferized, rather than building an op on a
+    // null buffer (same policy as ExtractSliceOpInterface).
+    Value srcMemref = lookup(bvm, extractOp.tensor());
+    if (!srcMemref)
+      return failure();
+    Value l = b.create(loc, srcMemref, extractOp.indices());
+    extractOp.replaceAllUsesWith(l);
+    return success();
+  }
+};
+
+/// Bufferization model for `insert_slice`.
+///
+/// It reads both operands and writes only `dest` (operand #1). Its result is
+/// in place with, and equivalent to, `dest`.
+struct InsertSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/;
+  }
+
+  SmallVector getAliasingOpOperand(Operation *op,
+                                   OpResult opResult) const {
+    return {&op->getOpOperand(1) /*dest*/};
+  }
+
+  OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/
+               ? op->getResult(0)
+               : OpResult();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::Equivalent;
+  }
+
+  /// Maps the result to the result buffer. It copies the source buffer into
+  /// a subview of that buffer unless both of the following hold: the result
+  /// is `InPlaceSpec::True`, and the source is equivalent to a matching
+  /// in-place extract_slice. Fails if either buffer is unavailable.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo,
+                          AllocationCallbacks &allocationFn) const {
+    auto insertSliceOp = cast(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(insertSliceOp);
+
+    LDBG("bufferize: " << *insertSliceOp << '\n');
+
+    Location loc = insertSliceOp.getLoc();
+    // Since insert_slice arise from tiling and introducing loops, this
+    // case is generally a deal breaker. When used with loops, this ends up
+    // cloning the whole tensor on every single iteration and is a symptom
+    // of a catastrophically bad scheduling decision.
+    // TODO: be very loud about it or even consider failing the pass.
+    // Alloc a copy for `insertSliceOp.dest()`, it will become the result
+    // buffer.
+    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm,
+                                      aliasInfo, allocationFn);
+    if (!dstMemref)
+      return failure();
+    auto dstMemrefType = dstMemref.getType().cast();
+
+    Value srcMemref = lookup(bvm, insertSliceOp.source());
+    if (!srcMemref)
+      return failure();
+    auto subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
+            insertSliceOp.getSourceType().getRank(), dstMemrefType,
+            insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+            insertSliceOp.getMixedStrides())
+            .cast();
+
+    // A copy of the source buffer is needed if either:
+    //   - The producer of `source` is not inplace. This is the case where a
+    //     slice is computed out of place into the inplace full tensor.
+    //   - The result is not inplace. This is the case where the whole tensor is
+    //     cloned and the clone needs to be updated.
+    auto inPlace = getInPlace(insertSliceOp->getResult(0));
+    // TODO: Is this necessary?
+    if (!aliasInfo.isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+            insertSliceOp) ||
+        inPlace != InPlaceSpec::True) {
+      LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
+                                                    << " -> copy\n");
+      // Take a subview of the dst.
+      Value subView = b.create(
+          loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
+          insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+      // Insert new alias.
+      aliasInfo.insertNewBufferAlias(subView, dstMemref);
+      b.create(insertSliceOp.getLoc(), srcMemref, subView);
+    }
+
+    map(bvm, insertSliceOp.result(), dstMemref);
+
+    return success();
+  }
+};
+
+} // namespace tensor_ext
+
+namespace vector_ext {
+
+/// Bufferization model for vector transfer_read on a tensor.
+///
+/// The tensor operand is read, never written, and has no in-place result.
+struct TransferReadOpInterface
+    : public BufferizableOpInterface::ExternalModel {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    assert(opOperand.get().getType().isa() &&
+           "only tensor types expected");
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    assert(opOperand.get().getType().isa() &&
+           "only tensor types expected");
+    return false;
+  }
+
+  OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  /// Rewrites the op in place so that it reads from the buffer of its source.
+  /// Fails when the shaped-type check at the top matches. This is presumably
+  /// the already-bufferized (memref) form; TODO confirm.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo,
+                          AllocationCallbacks &allocationFn) const {
+    auto transferReadOp = cast(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(op);
+
+    if (transferReadOp.getShapedType().isa())
+      return failure();
+
+    // TransferReadOp always reads from the bufferized op.source().
+    Value v = lookup(bvm, transferReadOp.source());
+    assert(v && "missing buffer");
+    transferReadOp.sourceMutable().assign(v);
+    return success();
+  }
+};
+
+/// Bufferization model for vector transfer_write on a tensor.
+///
+/// The tensor operand is both read and written. The op's result is in place
+/// with, and equivalent to, operand #1.
+struct TransferWriteOpInterface
+    : public BufferizableOpInterface::ExternalModel {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    assert(opOperand.get().getType().isa() &&
+           "only tensor types expected");
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    assert(opOperand.get().getType().isa() &&
+           "only tensor types expected");
+    return true;
+  }
+
+  SmallVector getAliasingOpOperand(Operation *op,
+                                   OpResult opResult) const {
+    return {&op->getOpOperand(1)};
+  }
+
+  OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const {
+    assert(opOperand.get().getType().isa() &&
+           "only tensor types expected");
+    return op->getOpResult(0);
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::Equivalent;
+  }
+
+  /// Emits a new, result-less transfer_write into the result buffer and maps
+  /// the original result to that buffer. Fails when the shaped-type check at
+  /// the top matches, or when no result buffer can be obtained.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo,
+                          AllocationCallbacks &allocationFn) const {
+    auto writeOp = cast(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(op);
+
+    if (writeOp.getShapedType().isa())
+      return failure();
+
+    // Create a new transfer_write on buffer that doesn't have a return value.
+    // Leave the previous transfer_write to dead code as it still has uses at
+    // this point.
+    Value resultBuffer =
+        getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn);
+    if (!resultBuffer)
+      return failure();
+    b.create(
+        writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
+        writeOp.permutation_map(),
+        writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
+    map(bvm, op->getResult(0), resultBuffer);
+
+    return success();
+  }
+};
+
+} // namespace vector_ext
+
+namespace {
+
+/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
+/// the `BufferizableOpInterface` with each of them.
+template
+struct LinalgOpInterfaceHelper;
+
+template
+struct LinalgOpInterfaceHelper {
+  static void registerOpInterface(DialectRegistry &registry) {
+    registry.addOpInterface>();
+    LinalgOpInterfaceHelper::registerOpInterface(registry);
+  }
+};
+
+/// Base case of the recursion: an empty op list registers nothing.
+template <>
+struct LinalgOpInterfaceHelper<> {
+  static void registerOpInterface(DialectRegistry &registry) {}
+};
+
+} // namespace
+
+/// Registers the `BufferizableOpInterface` external models with `registry`.
+/// It first adds the individually listed op/model pairs, then one model per
+/// Linalg structured op.
+/// Note: the `Bufferiable` spelling matches the declaration in
+/// ComprehensiveBufferize.h and is kept so existing callers still link.
+void registerBufferiableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addOpInterface();
+  registry.addOpInterface();
+  registry
+      .addOpInterface();
+  registry.addOpInterface();
+  registry.addOpInterface();
+  registry.addOpInterface();
+  registry.addOpInterface();
+  registry.addOpInterface();
+  registry.addOpInterface();
+  registry.addOpInterface();
+  registry.addOpInterface();
+  registry.addOpInterface();
+  registry.addOpInterface();
+  registry.addOpInterface();
+  registry.addOpInterface();
+  registry.addOpInterface();
+
+  // Register all Linalg structured ops. `LinalgOp` is an interface and it is
+  // not possible to attach an external interface to an existing interface.
+  // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
+  LinalgOpInterfaceHelper<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+      >::registerOpInterface(registry);
+}
+
+} // namespace linalg
+} // namespace mlir