diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -84,38 +84,6 @@ LogicalResult runComprehensiveBufferize(ModuleOp moduleOp, const BufferizationOptions &options); -namespace linalg_ext { - -struct InitTensorEliminationStep : public PostAnalysisStep { - /// Try to eliminate InitTensorOps inside `funcOp`. - /// - /// * `rewriteFunc` generates the replacement for the InitTensorOp. - /// * Only InitTensorOps that are anchored on a matching OpOperand as per - /// `anchorMatchFunc` are considered. "Anchored" means that there is a path - /// on the reverse SSA use-def chain, starting from the OpOperand and always - /// following the aliasing OpOperand, that eventually ends at a single - /// InitTensorOp. - /// * The result of `rewriteFunc` must usually be analyzed for inplacability. - /// This analysis can be skipped with `skipAnalysis`. - LogicalResult eliminateInitTensors( - FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo, - std::function anchorMatchFunc, - std::function rewriteFunc, - SmallVector &newOps); -}; - -/// Try to eliminate InitTensorOps inside funcOp that are anchored on an -/// InsertSliceOp, i.e., if it is eventually inserted into another tensor -/// (and some other conditions are met). -struct InsertSliceAnchoredInitTensorEliminationStep - : public InitTensorEliminationStep { - LogicalResult run(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - DominanceInfo &domInfo, - SmallVector &newOps) override; -}; - -} // namespace linalg_ext - } // namespace comprehensive_bufferize } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h @@ -0,0 +1,52 @@ +#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H +#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H + +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" + +namespace mlir { + +class DialectRegistry; + +namespace linalg { +namespace comprehensive_bufferize { + +class BufferizationAliasInfo; + +namespace linalg_ext { + +struct InitTensorEliminationStep : public PostAnalysisStep { + /// Try to eliminate InitTensorOps inside `funcOp`. + /// + /// * `rewriteFunc` generates the replacement for the InitTensorOp. + /// * Only InitTensorOps that are anchored on a matching OpOperand as per + /// `anchorMatchFunc` are considered. "Anchored" means that there is a path + /// on the reverse SSA use-def chain, starting from the OpOperand and always + /// following the aliasing OpOperand, that eventually ends at a single + /// InitTensorOp. + /// * The result of `rewriteFunc` must usually be analyzed for inplacability. + /// This analysis can be skipped with `skipAnalysis`. 
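+  ///
+  /// Example (illustrative IR, not part of this change): in
+  ///
+  ///     %0 = linalg.init_tensor ...
+  ///     %1 = linalg.fill(%cst, %0)
+  ///     %2 = tensor.insert_slice %1 into %t[..][..][..]
+  ///
+  /// the source operand of the tensor.insert_slice is anchored on the
+  /// InitTensorOp: walking backwards from that operand along aliasing
+  /// OpOperands (here the output operand of the fill) ends at the single
+  /// InitTensorOp %0.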
+ LogicalResult eliminateInitTensors( + FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo, + std::function anchorMatchFunc, + std::function rewriteFunc, + SmallVector &newOps); +}; + +/// Try to eliminate InitTensorOps inside funcOp that are anchored on an +/// InsertSliceOp, i.e., if it is eventually inserted into another tensor +/// (and some other conditions are met). +struct InsertSliceAnchoredInitTensorEliminationStep + : public InitTensorEliminationStep { + LogicalResult run(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, + DominanceInfo &domInfo, + SmallVector &newOps) override; +}; + +void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); + +} // namespace linalg_ext +} // namespace comprehensive_bufferize +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -1,6 +1,7 @@ set(LLVM_OPTIONAL_SOURCES BufferizableOpInterface.cpp ComprehensiveBufferize.cpp + LinalgInterfaceImpl.cpp ) add_mlir_dialect_library(MLIRBufferizableOpInterface @@ -13,15 +14,25 @@ MLIRIR ) +add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl + LinalgInterfaceImpl.cpp + + LINK_LIBS PUBLIC + MLIRBufferizableOpInterface + MLIRIR + MLIRLinalg + MLIRTensor +) + add_mlir_dialect_library(MLIRComprehensiveBufferize ComprehensiveBufferize.cpp LINK_LIBS PUBLIC + MLIRAffine MLIRBufferizableOpInterface MLIRInferTypeOpInterface MLIRIR MLIRMemRef - MLIRLinalg MLIRSCF MLIRStandard MLIRStandardOpsTransforms diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -109,14 +109,16 @@ #include +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/AsmState.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -1590,130 +1592,6 @@ } } -/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp is replaced -/// with the the result of `rewriteFunc` if it is anchored on a matching -/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def -/// chain, starting from the OpOperand and always following the aliasing -/// OpOperand, that eventually ends at a single InitTensorOp. 
-LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: - InitTensorEliminationStep::eliminateInitTensors( - FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - DominanceInfo &domInfo, - std::function anchorMatchFunc, - std::function rewriteFunc, - SmallVector &newOps) { - OpBuilder b(funcOp->getContext()); - - WalkResult status = funcOp->walk([&](Operation *op) { - for (OpOperand &operand : op->getOpOperands()) { - // Is this a matching OpOperand? - if (!anchorMatchFunc(operand)) - continue; - - SetVector maybeInitTensor = - findValueInReverseUseDefChain(operand.get(), [&](Value val) { - // Continue traversal until this function returns true. - OpResult opResult = val.dyn_cast(); - if (!opResult) - return true; - if (!aliasInfo.isInPlace(opResult)) - return true; - // Only equivalent tensors are supported at the moment. - // TODO: Support cases such as extract_slice(init_tensor). - SmallVector opOperands = - getAliasingOpOperand(opResult); - if (!llvm::all_of(opOperands, [](OpOperand *operand) { - return bufferRelation(*operand) == BufferRelation::Equivalent; - })) - return true; - return false; - }); - - // Replace only if the reverse use-def chain ends at exactly one - // InitTensorOp. - if (maybeInitTensor.size() != 1 || - !maybeInitTensor.front().getDefiningOp()) - return WalkResult::skip(); - Value initTensor = maybeInitTensor.front(); - - // Create a replacement for the InitTensorOp. - b.setInsertionPoint(initTensor.getDefiningOp()); - Value replacement = rewriteFunc(b, initTensor.getLoc(), operand); - if (!replacement) - continue; - - // Uses of the InitTensorOp are replaced here, but the op is not deleted. - // InitTensorOps without uses are ignored by the bufferization. - initTensor.replaceAllUsesWith(replacement); - aliasInfo.createAliasInfoEntry(replacement); - aliasInfo.unionAliasSets(initTensor, replacement); - aliasInfo.unionEquivalenceClasses(initTensor, replacement); - - // Register replacement ops. - if (Operation *newOp = replacement.getDefiningOp()) - newOps.push_back(newOp); - } - - // Advance to the next operation. - return WalkResult::advance(); - }); - - return failure(status.wasInterrupted()); -} - -/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp can be -/// eliminated if it is eventually inserted into another tensor (and some other -/// conditions are met). -/// -/// E.g.: -/// %0 = linalg.init_tensor -/// %1 = linalg.fill(%cst, %0) {inplace = [true]} -/// %2 = tensor.insert_slice %1 into %t[10][20][1] -/// -/// InitTensorOp elimination will try to fill %t inplace instead of filling a -/// new allocation %0 and inserting it into %t. This is done by replacing the -/// InitTensorOp with: -/// -/// %0 = tensor.extract_slice %t[10][20][1] -/// -/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets -/// those bufferize inplace in the absence of other conflicts. -/// -/// Starting from an InsertSliceOp, an InitTensorOp at the end of the insert -/// source's reverse use-def chain is eliminated if: -/// * The InsertSliceOp was decided to bufferize inplace. -/// * On the reverse use-def chain path from the InsertSliceOp to the -/// InitTensorOp, all ops were decided to bufferize inplace and the buffer -/// relation is "equivalent" (TODO: can be relaxed if needed). -/// * The reverse use-def chain has exactly one end, which is the InitTensorOp. -/// -/// Note that the newly inserted ExtractSliceOp may have to bufferize -/// out-of-place due to RaW conflicts. 
-LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: - InsertSliceAnchoredInitTensorEliminationStep::run( - FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - DominanceInfo &domInfo, SmallVector &newOps) { - return eliminateInitTensors( - funcOp, aliasInfo, domInfo, - [&](OpOperand &operand) { - auto insertSliceOp = dyn_cast(operand.getOwner()); - if (!insertSliceOp) - return false; - // Only inplace bufferized InsertSliceOps are eligible. - if (!aliasInfo.isInPlace(insertSliceOp->getOpResult(0))) - return false; - return &operand == &insertSliceOp->getOpOperand(0) /*source*/; - }, - [](OpBuilder &b, Location loc, OpOperand &operand) { - auto insertSliceOp = cast(operand.getOwner()); - auto extractOp = b.create( - loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(), - insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); - return extractOp.result(); - }, - newOps); -} - #ifndef NDEBUG /// Assert that the current bufferization decisions are consistent. static void checkAliasInfoConsistency(FuncOp funcOp, @@ -1914,368 +1792,6 @@ } // namespace arith_ext -// TODO: Ops in the linalg dialect can directly implement this interface. -namespace linalg_ext { - -/// Helper function for LinalgOp bufferization. -/// When allocating a new buffer, analyze whether `op` wants to read form that -/// buffer. Only in that case, a copy of the result buffer may be needed. -static LogicalResult -allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op, - SmallVectorImpl &resultBuffers, - BufferizationState &state) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - - // TODO: provide the proper interface to iterate on OpResults and get the - // matching OpOperands. - for (OpOperand *opOperand : op.getOutputOperands()) { - OpResult opResult = cast(op.getOperation()) - .getAliasingOpResult(*opOperand); - assert(opResult && "could not find correspond OpResult"); - Value resultBuffer = getResultBuffer(b, opResult, state); - if (!resultBuffer) - return failure(); - resultBuffers.push_back(resultBuffer); - } - - if (op->getNumResults()) - state.mapBuffer(op->getResults(), resultBuffers); - - return success(); -} - -/// Generic conversion for any LinalgOp on tensors. -static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op, - BufferizationState &state) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - - // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need - // basis. - if (!op.hasTensorSemantics()) - return op->emitError() << "op does not have tensor semantics"; - - Location loc = op.getLoc(); - SmallVector newInputBuffers; - newInputBuffers.reserve(op.getNumInputs()); - for (OpOperand *opOperand : op.getInputOperands()) { - if (op.isScalar(opOperand)) { - newInputBuffers.push_back(opOperand->get()); - continue; - } - newInputBuffers.push_back(state.lookupBuffer(opOperand->get())); - } - SmallVector newOutputBuffers; - // Try to allocate new buffers depending on op's inplace semantics. - if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, state))) - return failure(); - - // Clone the newly bufferized op. - SmallVector newOperands = newInputBuffers; - newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); - - // Set insertion point now that potential alloc/dealloc are introduced. - b.setInsertionPoint(op); - op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands); - - // Replace the results of the old op with the new output buffers. 
- if (op->getNumResults()) - state.mapBuffer(op->getResults(), newOutputBuffers); - - // The original op will be DCE'd away later. - - return success(); -} - -template -struct LinalgOpInterface - : public BufferizableOpInterface::ExternalModel, - OpTy> { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { - auto genericOp = cast(op); - return (genericOp.isInputTensor(&opOperand) || - genericOp.isInitTensor(&opOperand)) && - genericOp.payloadUsesValueFromOperand(&opOperand); - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { - auto genericOp = cast(op); - return genericOp.isOutputTensor(&opOperand); - } - - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - auto genericOp = cast(op); - return {genericOp.getOutputTensorOperands()[opResult.getResultNumber()]}; - } - - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { - auto genericOp = cast(op); - if (!opOperand.get().getType().isa()) - return OpResult(); - // For now assume inputs are never inplaceable. - // TODO: refine this. - if (opOperand.getOperandNumber() < genericOp.getNumInputs()) - return OpResult(); - int64_t outputOperandIndex = - opOperand.getOperandNumber() - genericOp.getNumInputs(); - int64_t numOutputBuffers = 0; - for (unsigned idx = 0; idx < outputOperandIndex; ++idx) - if (!genericOp.getOutputOperand(idx)->get().getType().isa()) - ++numOutputBuffers; - return genericOp->getResult(outputOperandIndex - numOutputBuffers); - } - - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { - return BufferRelation::Equivalent; - } - - LogicalResult bufferize(Operation *op, OpBuilder &b, - BufferizationState &state) const { - return bufferizeLinalgOp(b, cast(op), state); - } -}; - -struct InitTensorOpInterface - : public BufferizableOpInterface::ExternalModel { - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {}; - } - - bool isMemoryWrite(Operation *op, OpResult opResult) const { - // InitTensorOps allocate but do not write. - return false; - } - - LogicalResult bufferize(Operation *op, OpBuilder &b, - BufferizationState &state) const { - auto initTensorOp = cast(op); - - // The InitTensorOp may have been eliminated. - if (initTensorOp->getUses().empty()) - return success(); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(initTensorOp); - - Value alloc = createNewAllocDeallocPairForShapedValue( - b, initTensorOp->getLoc(), initTensorOp.result(), state); - state.mapBuffer(initTensorOp.result(), alloc); - return success(); - } -}; - -struct TiledLoopOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { - // TiledLoop alone doesn't bufferize to a memory read, one of the uses of - // its matching bbArg may. - auto tiledLoopOp = cast(op); - return isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand)); - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { - // TiledLoop alone doesn't bufferize to a memory write, one of the uses of - // its matching bbArg may. - auto bufferizableOp = cast(op); - return static_cast(bufferizableOp.getAliasingOpResult(opOperand)); - } - - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - // TODO: TiledLoopOp helper method to avoid leaking impl details. 
- auto tiledLoopOp = cast(op); - return {&op->getOpOperand(tiledLoopOp.getNumControlOperands() + - tiledLoopOp.getNumInputs() + - opResult.getResultNumber())}; - } - - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { - auto tiledLoopOp = cast(op); - return tiledLoopOp.getTiedOpResult(opOperand); - } - - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { - return BufferRelation::Equivalent; - } - - bool isWritable(Operation *op, Value value) const { - // Interestingly, linalg::TiledLoopOp's bbArg can **always** be viewed - // inplace from the perspective of ops nested under: - // 1. Either the matching iter operand is not bufferized inplace and an - // alloc + optional copy makes the bbArg itself inplaceable. - // 2. Or the matching iter operand is bufferized inplace and bbArg just - // bufferizes to that too. - return true; - } - - bool isAllocationHoistingBarrier(Operation *op) const { return true; } - - LogicalResult bufferize(Operation *op, OpBuilder &b, - BufferizationState &state) const { - auto tiledLoopOp = cast(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - - // Allocate output buffers if needed, forward output tensor args to the - // terminator. - Operation *yieldOp = tiledLoopOp.getBody()->getTerminator(); - Block *body = tiledLoopOp.getBody(); - - // Take copies of the old input and output operands, so we can insert - // inplace easily. - auto oldInputs = llvm::to_vector<4>(tiledLoopOp.inputs()); - auto oldOutputs = llvm::to_vector<4>(tiledLoopOp.outputs()); - - int numLoops = tiledLoopOp.getNumLoops(); - int numControlOperands = tiledLoopOp.getNumControlOperands(); - - // Add buffers for outputs and the corresponding block arguments. - // Keep separate iterators to increment without further leaking impl. - // details. Start with outputs to avoid interference from new input buffers. - int numNewOutputBuffers = 0; - int resultIndex = 0; - int oldOutputBBArgIndex = numLoops + oldInputs.size(); - int nextOutputBBArgIndex = numLoops + oldInputs.size() + oldOutputs.size(); - int nextOutputOperandIndex = - numControlOperands + oldInputs.size() + oldOutputs.size(); - for (Value oldOutputTensor : oldOutputs) { - if (!oldOutputTensor.getType().isa()) { - // Skip and increment the old bbarg index only. - ++oldOutputBBArgIndex; - // Do not increment resultIndex as only tensors are returned. - // TODO: better interface to avoid leaking such impl details. - continue; - } - - assert(oldOutputTensor.getType().isa() && - "bufferizable output must be a ranked tensor"); - - const OpResult &opResult = tiledLoopOp->getResult(resultIndex); - OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex); - Value resultBuffer = getResultBuffer(b, opResult, state); - if (!resultBuffer) - return failure(); - - // Insert mapping and aliasing info. - state.aliasInfo.createAliasInfoEntry(resultBuffer); - state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer); - state.mapBuffer(opResult, resultBuffer); - - // Insert new operand and bbArg. - tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer); - BlockArgument newBufferBBArg = - body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType()); - BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex); - // Insert mapping and aliasing info. 
- state.aliasInfo.createAliasInfoEntry(newBufferBBArg); - state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, - newBufferBBArg); - state.mapBuffer(oldTensorBBArg, newBufferBBArg); - - // Set operand of `linalg.yield` to the bbArg so it just canonicalizes - // away later. - yieldOperand.set(oldTensorBBArg); - - // Increment indices. - ++numNewOutputBuffers; - ++resultIndex; - ++oldOutputBBArgIndex; - ++nextOutputBBArgIndex; - ++nextOutputOperandIndex; - } - - // Add buffers for inputs and the corresponding block arguments. - // Keep separate iterators to increment without further leaking impl. - // details. - int numNewInputBuffers = 0; - int oldInputBBArgIndex = numLoops; - int nextInputBBArgIndex = numLoops + oldInputs.size(); - int nextInputOperandIndex = numControlOperands + oldInputs.size(); - for (Value oldInputTensor : oldInputs) { - if (!oldInputTensor.getType().isa()) { - // Skip and increment the old bbarg index only. - ++oldInputBBArgIndex; - continue; - } - - Value inputBuffer = state.lookupBuffer(oldInputTensor); - - // Insert new operand and bbArg. - tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer); - BlockArgument newBufferBBArg = - body->insertArgument(nextInputBBArgIndex, inputBuffer.getType()); - BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex); - - // Insert mapping and aliasing info. - state.aliasInfo.createAliasInfoEntry(newBufferBBArg); - state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, - newBufferBBArg); - state.mapBuffer(oldTensorBBArg, newBufferBBArg); - - // Increment indices. - ++numNewInputBuffers; - ++oldInputBBArgIndex; - ++nextInputBBArgIndex; - ++nextInputOperandIndex; - } - - // Update segment sizes. - // TODO: Helper method to avoid leaking impl details. - tiledLoopOp->setAttr( - TiledLoopOp::getOperandSegmentSizeAttr(), - b.getI32VectorAttr( - {numLoops, numLoops, numLoops, - static_cast(oldInputs.size()) + numNewInputBuffers, - static_cast(oldOutputs.size()) + numNewOutputBuffers})); - - return success(); - } -}; - -struct YieldOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { - return false; - } - - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { - return OpResult(); - } - - LogicalResult bufferize(Operation *op, OpBuilder &b, - BufferizationState &state) const { - auto yieldOp = cast(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - // Cannot create IR past a yieldOp. - b.setInsertionPoint(yieldOp); - - // No tensors -> success. - if (!llvm::any_of(yieldOp.getOperandTypes(), isaTensor)) - return success(); - // linalg::YieldOp nested under TiledLoop must just canonicalize. - if (yieldOp->getParentOfType()) - return success(); - llvm_unreachable("unexpected yieldOp"); - } -}; - -} // namespace linalg_ext - namespace scf_ext { struct IfOpInterface @@ -3016,33 +2532,8 @@ } // namespace vector_ext -namespace { - -/// Helper structure that iterates over all LinalgOps in `OpTys` and registers -/// the `BufferizableOpInterface` with each of them. 
-template struct LinalgOpInterfaceHelper; - -template -struct LinalgOpInterfaceHelper { - static void registerOpInterface(DialectRegistry ®istry) { - registry.addOpInterface>(); - LinalgOpInterfaceHelper::registerOpInterface(registry); - } -}; - -template <> struct LinalgOpInterfaceHelper<> { - static void registerOpInterface(DialectRegistry ®istry) {} -}; - -} // namespace - void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { registry.addOpInterface(); - registry.addOpInterface(); - registry - .addOpInterface(); - registry.addOpInterface(); registry.addOpInterface(); registry.addOpInterface(); registry.addOpInterface(); @@ -3066,14 +2557,6 @@ AllocationHoistingBarrierOnly>(); registry.addOpInterface>(); - - // Register all Linalg structured ops. `LinalgOp` is an interface and it is - // not possible to attach an external interface to an existing interface. - // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one. - LinalgOpInterfaceHelper< -#define GET_OP_LIST -#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" - >::registerOpInterface(registry); } } // namespace comprehensive_bufferize diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -0,0 +1,540 @@ +//===- LinalgInterfaceImpl.cpp - Linalg Impl. of BufferizableOpInterface --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +using namespace mlir; +using namespace linalg; +using namespace comprehensive_bufferize; + +namespace { + +// TODO: Ops in the linalg dialect can directly implement this interface. + +/// Helper function for LinalgOp bufferization. +/// When allocating a new buffer, analyze whether `op` wants to read form that +/// buffer. Only in that case, a copy of the result buffer may be needed. +static LogicalResult +allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op, + SmallVectorImpl &resultBuffers, + BufferizationState &state) { + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(op); + + // TODO: provide the proper interface to iterate on OpResults and get the + // matching OpOperands. + for (OpOperand *opOperand : op.getOutputOperands()) { + OpResult opResult = cast(op.getOperation()) + .getAliasingOpResult(*opOperand); + assert(opResult && "could not find correspond OpResult"); + Value resultBuffer = getResultBuffer(b, opResult, state); + if (!resultBuffer) + return failure(); + resultBuffers.push_back(resultBuffer); + } + + if (op->getNumResults()) + state.mapBuffer(op->getResults(), resultBuffers); + + return success(); +} + +/// Generic conversion for any LinalgOp on tensors. +static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op, + BufferizationState &state) { + // Take a guard before anything else. 
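+  // (The guard restores the builder's previous insertion point when it goes
+  // out of scope, so the insertion-point changes made below stay local.)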
+ OpBuilder::InsertionGuard g(b); + + // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need + // basis. + if (!op.hasTensorSemantics()) + return op->emitError() << "op does not have tensor semantics"; + + Location loc = op.getLoc(); + SmallVector newInputBuffers; + newInputBuffers.reserve(op.getNumInputs()); + for (OpOperand *opOperand : op.getInputOperands()) { + if (op.isScalar(opOperand)) { + newInputBuffers.push_back(opOperand->get()); + continue; + } + newInputBuffers.push_back(state.lookupBuffer(opOperand->get())); + } + SmallVector newOutputBuffers; + // Try to allocate new buffers depending on op's inplace semantics. + if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, state))) + return failure(); + + // Clone the newly bufferized op. + SmallVector newOperands = newInputBuffers; + newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); + + // Set insertion point now that potential alloc/dealloc are introduced. + b.setInsertionPoint(op); + op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands); + + // Replace the results of the old op with the new output buffers. + if (op->getNumResults()) + state.mapBuffer(op->getResults(), newOutputBuffers); + + // The original op will be DCE'd away later. + + return success(); +} + +template +struct LinalgOpInterface + : public BufferizableOpInterface::ExternalModel, + OpTy> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + auto genericOp = cast(op); + return (genericOp.isInputTensor(&opOperand) || + genericOp.isInitTensor(&opOperand)) && + genericOp.payloadUsesValueFromOperand(&opOperand); + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + auto genericOp = cast(op); + return genericOp.isOutputTensor(&opOperand); + } + + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + auto genericOp = cast(op); + return {genericOp.getOutputTensorOperands()[opResult.getResultNumber()]}; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + auto genericOp = cast(op); + if (!opOperand.get().getType().isa()) + return OpResult(); + // For now assume inputs are never inplaceable. + // TODO: refine this. + if (opOperand.getOperandNumber() < genericOp.getNumInputs()) + return OpResult(); + int64_t outputOperandIndex = + opOperand.getOperandNumber() - genericOp.getNumInputs(); + int64_t numOutputBuffers = 0; + for (unsigned idx = 0; idx < outputOperandIndex; ++idx) + if (!genericOp.getOutputOperand(idx)->get().getType().isa()) + ++numOutputBuffers; + return genericOp->getResult(outputOperandIndex - numOutputBuffers); + } + + BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + return BufferRelation::Equivalent; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + return bufferizeLinalgOp(b, cast(op), state); + } +}; + +struct InitTensorOpInterface + : public BufferizableOpInterface::ExternalModel { + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + return {}; + } + + bool isMemoryWrite(Operation *op, OpResult opResult) const { + // InitTensorOps allocate but do not write. + return false; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + auto initTensorOp = cast(op); + + // The InitTensorOp may have been eliminated. + if (initTensorOp->getUses().empty()) + return success(); + + // Take a guard before anything else. 
+ OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(initTensorOp); + + Value alloc = state.allocationFns.createAllocDeallocFn( + b, initTensorOp->getLoc(), initTensorOp.result(), state); + state.mapBuffer(initTensorOp.result(), alloc); + return success(); + } +}; + +struct TiledLoopOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + // TiledLoop alone doesn't bufferize to a memory read, one of the uses of + // its matching bbArg may. + auto tiledLoopOp = cast(op); + return isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand)); + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + // TiledLoop alone doesn't bufferize to a memory write, one of the uses of + // its matching bbArg may. + auto bufferizableOp = cast(op); + return static_cast(bufferizableOp.getAliasingOpResult(opOperand)); + } + + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + // TODO: TiledLoopOp helper method to avoid leaking impl details. + auto tiledLoopOp = cast(op); + return {&op->getOpOperand(tiledLoopOp.getNumControlOperands() + + tiledLoopOp.getNumInputs() + + opResult.getResultNumber())}; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + auto tiledLoopOp = cast(op); + return tiledLoopOp.getTiedOpResult(opOperand); + } + + BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + return BufferRelation::Equivalent; + } + + bool isWritable(Operation *op, Value value) const { + // Interestingly, linalg::TiledLoopOp's bbArg can **always** be viewed + // inplace from the perspective of ops nested under: + // 1. Either the matching iter operand is not bufferized inplace and an + // alloc + optional copy makes the bbArg itself inplaceable. + // 2. Or the matching iter operand is bufferized inplace and bbArg just + // bufferizes to that too. + return true; + } + + bool isAllocationHoistingBarrier(Operation *op) const { return true; } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + auto tiledLoopOp = cast(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + + // Allocate output buffers if needed, forward output tensor args to the + // terminator. + Operation *yieldOp = tiledLoopOp.getBody()->getTerminator(); + Block *body = tiledLoopOp.getBody(); + + // Take copies of the old input and output operands, so we can insert + // inplace easily. + auto oldInputs = llvm::to_vector<4>(tiledLoopOp.inputs()); + auto oldOutputs = llvm::to_vector<4>(tiledLoopOp.outputs()); + + int numLoops = tiledLoopOp.getNumLoops(); + int numControlOperands = tiledLoopOp.getNumControlOperands(); + + // Add buffers for outputs and the corresponding block arguments. + // Keep separate iterators to increment without further leaking impl. + // details. Start with outputs to avoid interference from new input buffers. + int numNewOutputBuffers = 0; + int resultIndex = 0; + int oldOutputBBArgIndex = numLoops + oldInputs.size(); + int nextOutputBBArgIndex = numLoops + oldInputs.size() + oldOutputs.size(); + int nextOutputOperandIndex = + numControlOperands + oldInputs.size() + oldOutputs.size(); + for (Value oldOutputTensor : oldOutputs) { + if (!oldOutputTensor.getType().isa()) { + // Skip and increment the old bbarg index only. + ++oldOutputBBArgIndex; + // Do not increment resultIndex as only tensors are returned. + // TODO: better interface to avoid leaking such impl details. 
+ continue; + } + + assert(oldOutputTensor.getType().isa() && + "bufferizable output must be a ranked tensor"); + + const OpResult &opResult = tiledLoopOp->getResult(resultIndex); + OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex); + Value resultBuffer = getResultBuffer(b, opResult, state); + if (!resultBuffer) + return failure(); + + // Insert mapping and aliasing info. + state.aliasInfo.createAliasInfoEntry(resultBuffer); + state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer); + state.mapBuffer(opResult, resultBuffer); + + // Insert new operand and bbArg. + tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer); + BlockArgument newBufferBBArg = + body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType()); + BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex); + // Insert mapping and aliasing info. + state.aliasInfo.createAliasInfoEntry(newBufferBBArg); + state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, + newBufferBBArg); + state.mapBuffer(oldTensorBBArg, newBufferBBArg); + + // Set operand of `linalg.yield` to the bbArg so it just canonicalizes + // away later. + yieldOperand.set(oldTensorBBArg); + + // Increment indices. + ++numNewOutputBuffers; + ++resultIndex; + ++oldOutputBBArgIndex; + ++nextOutputBBArgIndex; + ++nextOutputOperandIndex; + } + + // Add buffers for inputs and the corresponding block arguments. + // Keep separate iterators to increment without further leaking impl. + // details. + int numNewInputBuffers = 0; + int oldInputBBArgIndex = numLoops; + int nextInputBBArgIndex = numLoops + oldInputs.size(); + int nextInputOperandIndex = numControlOperands + oldInputs.size(); + for (Value oldInputTensor : oldInputs) { + if (!oldInputTensor.getType().isa()) { + // Skip and increment the old bbarg index only. + ++oldInputBBArgIndex; + continue; + } + + Value inputBuffer = state.lookupBuffer(oldInputTensor); + + // Insert new operand and bbArg. + tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer); + BlockArgument newBufferBBArg = + body->insertArgument(nextInputBBArgIndex, inputBuffer.getType()); + BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex); + + // Insert mapping and aliasing info. + state.aliasInfo.createAliasInfoEntry(newBufferBBArg); + state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, + newBufferBBArg); + state.mapBuffer(oldTensorBBArg, newBufferBBArg); + + // Increment indices. + ++numNewInputBuffers; + ++oldInputBBArgIndex; + ++nextInputBBArgIndex; + ++nextInputOperandIndex; + } + + // Update segment sizes. + // TODO: Helper method to avoid leaking impl details. + tiledLoopOp->setAttr( + TiledLoopOp::getOperandSegmentSizeAttr(), + b.getI32VectorAttr( + {numLoops, numLoops, numLoops, + static_cast(oldInputs.size()) + numNewInputBuffers, + static_cast(oldOutputs.size()) + numNewOutputBuffers})); + + return success(); + } +}; + +struct YieldOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + return false; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + return OpResult(); + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + auto yieldOp = cast(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + // Cannot create IR past a yieldOp. 
+ b.setInsertionPoint(yieldOp); + + // No tensors -> success. + if (!llvm::any_of(yieldOp.getOperandTypes(), + [](Type t) { return t.isa(); })) + return success(); + // linalg::YieldOp nested under TiledLoop must just canonicalize. + if (yieldOp->getParentOfType()) + return success(); + llvm_unreachable("unexpected yieldOp"); + } +}; + +/// Helper structure that iterates over all LinalgOps in `OpTys` and registers +/// the `BufferizableOpInterface` with each of them. +template +struct LinalgOpInterfaceHelper; + +template +struct LinalgOpInterfaceHelper { + static void registerOpInterface(DialectRegistry ®istry) { + registry.addOpInterface>(); + LinalgOpInterfaceHelper::registerOpInterface(registry); + } +}; + +template <> +struct LinalgOpInterfaceHelper<> { + static void registerOpInterface(DialectRegistry ®istry) {} +}; + +} // namespace + +/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp is replaced +/// with the the result of `rewriteFunc` if it is anchored on a matching +/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def +/// chain, starting from the OpOperand and always following the aliasing +/// OpOperand, that eventually ends at a single InitTensorOp. +LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: + InitTensorEliminationStep::eliminateInitTensors( + FuncOp funcOp, BufferizationAliasInfo &aliasInfo, + DominanceInfo &domInfo, + std::function anchorMatchFunc, + std::function rewriteFunc, + SmallVector &newOps) { + OpBuilder b(funcOp->getContext()); + + WalkResult status = funcOp->walk([&](Operation *op) { + for (OpOperand &operand : op->getOpOperands()) { + // Is this a matching OpOperand? + if (!anchorMatchFunc(operand)) + continue; + + SetVector maybeInitTensor = + findValueInReverseUseDefChain(operand.get(), [&](Value val) { + // Continue traversal until this function returns true. + OpResult opResult = val.dyn_cast(); + if (!opResult) + return true; + if (!aliasInfo.isInPlace(opResult)) + return true; + // Only equivalent tensors are supported at the moment. + // TODO: Support cases such as extract_slice(init_tensor). + SmallVector opOperands = + getAliasingOpOperand(opResult); + if (!llvm::all_of(opOperands, [](OpOperand *operand) { + return bufferRelation(*operand) == BufferRelation::Equivalent; + })) + return true; + return false; + }); + + // Replace only if the reverse use-def chain ends at exactly one + // InitTensorOp. + if (maybeInitTensor.size() != 1 || + !maybeInitTensor.front().getDefiningOp()) + return WalkResult::skip(); + Value initTensor = maybeInitTensor.front(); + + // Create a replacement for the InitTensorOp. + b.setInsertionPoint(initTensor.getDefiningOp()); + Value replacement = rewriteFunc(b, initTensor.getLoc(), operand); + if (!replacement) + continue; + + // Uses of the InitTensorOp are replaced here, but the op is not deleted. + // InitTensorOps without uses are ignored by the bufferization. + initTensor.replaceAllUsesWith(replacement); + aliasInfo.createAliasInfoEntry(replacement); + aliasInfo.unionAliasSets(initTensor, replacement); + aliasInfo.unionEquivalenceClasses(initTensor, replacement); + + // Register replacement ops. + if (Operation *newOp = replacement.getDefiningOp()) + newOps.push_back(newOp); + } + + // Advance to the next operation. + return WalkResult::advance(); + }); + + return failure(status.wasInterrupted()); +} + +/// Try to eliminate InitTensorOps inside funcOp. 
An InitTensorOp can be +/// eliminated if it is eventually inserted into another tensor (and some other +/// conditions are met). +/// +/// E.g.: +/// %0 = linalg.init_tensor +/// %1 = linalg.fill(%cst, %0) {inplace = [true]} +/// %2 = tensor.insert_slice %1 into %t[10][20][1] +/// +/// InitTensorOp elimination will try to fill %t inplace instead of filling a +/// new allocation %0 and inserting it into %t. This is done by replacing the +/// InitTensorOp with: +/// +/// %0 = tensor.extract_slice %t[10][20][1] +/// +/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets +/// those bufferize inplace in the absence of other conflicts. +/// +/// Starting from an InsertSliceOp, an InitTensorOp at the end of the insert +/// source's reverse use-def chain is eliminated if: +/// * The InsertSliceOp was decided to bufferize inplace. +/// * On the reverse use-def chain path from the InsertSliceOp to the +/// InitTensorOp, all ops were decided to bufferize inplace and the buffer +/// relation is "equivalent" (TODO: can be relaxed if needed). +/// * The reverse use-def chain has exactly one end, which is the InitTensorOp. +/// +/// Note that the newly inserted ExtractSliceOp may have to bufferize +/// out-of-place due to RaW conflicts. +LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: + InsertSliceAnchoredInitTensorEliminationStep::run( + FuncOp funcOp, BufferizationAliasInfo &aliasInfo, + DominanceInfo &domInfo, SmallVector &newOps) { + return eliminateInitTensors( + funcOp, aliasInfo, domInfo, + [&](OpOperand &operand) { + auto insertSliceOp = + dyn_cast(operand.getOwner()); + if (!insertSliceOp) + return false; + // Only inplace bufferized InsertSliceOps are eligible. + if (!aliasInfo.isInPlace(insertSliceOp->getOpResult(0))) + return false; + return &operand == &insertSliceOp->getOpOperand(0) /*source*/; + }, + [](OpBuilder &b, Location loc, OpOperand &operand) { + auto insertSliceOp = cast(operand.getOwner()); + auto extractOp = b.create( + loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(), + insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); + return extractOp.result(); + }, + newOps); +} + +void mlir::linalg::comprehensive_bufferize::linalg_ext:: + registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { + registry.addOpInterface(); + registry.addOpInterface(); + registry.addOpInterface(); + + // Register all Linalg structured ops. `LinalgOp` is an interface and it is + // not possible to attach an external interface to an existing interface. + // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one. 
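+  // The op list is pulled in via GET_OP_LIST from the generated
+  // LinalgStructuredOps.cpp.inc included below.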
+ LinalgOpInterfaceHelper< +#define GET_OP_LIST +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >::registerOpInterface(registry); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -38,8 +38,9 @@ MLIRInferTypeOpInterface MLIRIR MLIRMemRef - MLIRLinalgAnalysis MLIRLinalg + MLIRLinalgAnalysis + MLIRLinalgBufferizableOpInterfaceImpl MLIRLinalgUtils MLIRSCF MLIRSCFTransforms diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -9,6 +9,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -36,6 +37,7 @@ tensor::TensorDialect, vector::VectorDialect, scf::SCFDialect, arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>(); registerBufferizableOpInterfaceExternalModels(registry); + linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); } }; } // end namespace diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6284,6 +6284,26 @@ ], ) +cc_library( + name = "LinalgBufferizableOpInterfaceImpl", + srcs = [ + "lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp", + ], + hdrs = [ + "include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h", + ], + includes = ["include"], + deps = [ + ":BufferizableOpInterface", + ":IR", + ":LinalgOps", + ":LinalgStructuredOpsIncGen", + ":Support", + ":TensorDialect", + "//llvm:Support", + ], +) + td_library( name = "LinalgDocTdFiles", srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"], @@ -6491,6 +6511,7 @@ ":DialectUtils", ":IR", ":InferTypeOpInterface", + ":LinalgBufferizableOpInterfaceImpl", ":LinalgOps", ":LinalgPassIncGen", ":LinalgStructuredOpsIncGen", @@ -6521,13 +6542,12 @@ ], includes = ["include"], deps = [ + ":Affine", ":ArithmeticDialect", ":BufferizableOpInterface", ":DialectUtils", ":IR", ":InferTypeOpInterface", - ":LinalgOps", - ":LinalgStructuredOpsIncGen", ":MemRefDialect", ":Pass", ":SCFDialect",