diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -70,7 +70,7 @@ /// Run the post analysis step. This function may modify the IR, but must keep /// `aliasInfo` (inside `state`) consistent. Newly created operations and /// operations that should be re-analyzed must be stored in `newOps`. - virtual LogicalResult run(FuncOp funcOp, BufferizationState &state, + virtual LogicalResult run(Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) = 0; }; @@ -299,9 +299,8 @@ /// directly return a mapped buffer or allocate a new brand new buffer. class BufferizationState { public: - BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options) - : aliasInfo(moduleOp), options(options), - builder(moduleOp->getContext()) {} + BufferizationState(Operation *op, const BufferizationOptions &options) + : aliasInfo(op), options(options), builder(op->getContext()) {} // BufferizationState should be passed as a reference. 
BufferizationState(const BufferizationState &) = delete; @@ -365,7 +364,7 @@ private: friend LogicalResult - runComprehensiveBufferize(FuncOp funcOp, const BufferizationOptions &options, + runComprehensiveBufferize(Operation *op, const BufferizationOptions &options, BufferizationState &state, const PostAnalysisStepList &extraSteps); diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -196,7 +196,7 @@ before returning. Otherwise, nested ops will not be bufferized. This method will never be called on ops that do not have at least one - tensor operand or result. + tensor operand/result or a region. }], /*retType=*/"LogicalResult", /*methodName=*/"bufferize", diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -20,15 +20,16 @@ struct BufferizationState; struct PostAnalysisStep; -/// Bufferize the given function. Does not bufferize the function boundary. -/// Reuses an existing BufferizationState object. -// TODO: This function is meant to be called from ModuleBufferize and not can -// not yet be called standalone. +/// Bufferize the given operation. Reuses an existing BufferizationState object. LogicalResult runComprehensiveBufferize( - FuncOp funcOp, const BufferizationOptions &options, + Operation *op, const BufferizationOptions &options, BufferizationState &state, const std::vector> &extraSteps); +/// Bufferize the given operation. 
+LogicalResult runComprehensiveBufferize(Operation *op, + const BufferizationOptions &options); + } // namespace comprehensive_bufferize } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h @@ -23,7 +23,7 @@ namespace linalg_ext { struct InitTensorEliminationStep : public PostAnalysisStep { - /// Try to eliminate InitTensorOps inside `funcOp`. + /// Try to eliminate InitTensorOps inside `op`. /// /// * `rewriteFunc` generates the replacement for the InitTensorOp. /// * Only InitTensorOps that are anchored on a matching OpOperand as per @@ -34,19 +34,19 @@ /// * The result of `rewriteFunc` must usually be analyzed for inplacability. /// This analysis can be skipped with `skipAnalysis`. LogicalResult eliminateInitTensors( - FuncOp funcOp, BufferizationState &state, + Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, std::function anchorMatchFunc, std::function rewriteFunc, SmallVector &newOps); }; -/// Try to eliminate InitTensorOps inside funcOp that are anchored on an +/// Try to eliminate InitTensorOps inside `op` that are anchored on an /// InsertSliceOp, i.e., if it is eventually inserted into another tensor /// (and some other conditions are met). 
struct InsertSliceAnchoredInitTensorEliminationStep : public InitTensorEliminationStep { - LogicalResult run(FuncOp funcOp, BufferizationState &state, + LogicalResult run(Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) override; }; diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h @@ -22,7 +22,7 @@ /// Equivalence analysis for scf.for. Raise an error if iter_args are not /// equivalent to their corresponding loop yield values. struct AssertDestinationPassingStyle : public PostAnalysisStep { - LogicalResult run(FuncOp funcOp, BufferizationState &state, + LogicalResult run(Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) override; }; diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h @@ -20,7 +20,7 @@ namespace tensor_ext { struct InplaceInsertSliceOpAnalysis : public PostAnalysisStep { - LogicalResult run(FuncOp funcOp, BufferizationState &state, + LogicalResult run(Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) override; }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -425,14 +425,11 
@@ auto isaTensor = [](Type t) { return t.isa(); }; bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); + bool hasRegions = !op->getRegions().empty(); - // No tensor results or operands: Simply bufferize all nested ops. - if (!hasTensorResult && !hasTensorOperand) { - for (Region &region : op->getRegions()) - if (failed(bufferize(&region, state))) - return failure(); + // No tensor results/operands or regions. We are done. + if (!hasTensorResult && !hasTensorOperand && !hasRegions) return success(); - } // Bufferize using `BufferizableOpInterface`. Interface implementations are // responsible for bufferizing nested ops. @@ -449,6 +446,8 @@ for (OpOperand &operand : op->getOpOperands()) { if (operand.get().getType().isa() && state.isMapped(operand.get())) { + assert(state.getOptions().allowUnknownOps && + "unsupported op error should have been emitted earlier"); b.setInsertionPoint(op); Value toTensorOp = b.create( op->getLoc(), state.lookupBuffer(operand.get())); @@ -456,6 +455,7 @@ } } + // Bufferize all regions. for (Region &region : op->getRegions()) if (failed(bufferize(&region, state))) return failure(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -667,10 +667,10 @@ /// Assert that the current bufferization decisions are consistent. 
static LogicalResult -checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo, +checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo, const BufferizationAliasInfo &aliasInfo) { Operation *inconsistentOp = nullptr; - WalkResult walkResult = funcOp.walk([&](Operation *op) { + WalkResult walkResult = op->walk([&](Operation *op) { if (auto bufferizableOp = dyn_cast(op)) for (OpOperand &opOperand : op->getOpOperands()) if (opOperand.get().getType().isa()) { @@ -710,20 +710,23 @@ } LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( - FuncOp funcOp, const BufferizationOptions &options, + Operation *op, const BufferizationOptions &options) { + BufferizationState state(op, options); + PostAnalysisStepList extraSteps; + return runComprehensiveBufferize(op, options, state, extraSteps); +} + +LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( + Operation *op, const BufferizationOptions &options, BufferizationState &state, const PostAnalysisStepList &extraSteps) { - DominanceInfo domInfo(funcOp); + DominanceInfo domInfo(op); BufferizationAliasInfo &aliasInfo = state.aliasInfo; - if (funcOp.body().empty()) - return success(); - - if (failed(checkAliasInfoConsistency(funcOp, domInfo, aliasInfo))) + if (failed(checkAliasInfoConsistency(op, domInfo, aliasInfo))) return failure(); // If the analysis fails, just return. - Operation *op = funcOp.getOperation(); if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo, options.analysisFuzzerSeed))) return failure(); @@ -732,7 +735,7 @@ auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) { for (const std::unique_ptr &step : steps) { SmallVector newOps; - if (failed(step->run(funcOp, state, aliasInfo, newOps))) + if (failed(step->run(op, state, aliasInfo, newOps))) return failure(); // Analyze ops that were created by the PostAnalysisStep. 
if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo))) @@ -749,16 +752,12 @@ // Annotate operations if we only want to report the analysis. if (options.testAnalysisOnly) { - annotateOpsWithBufferizationMarkers(funcOp, aliasInfo); + annotateOpsWithBufferizationMarkers(op, aliasInfo); return success(); } - // Bufferize all ops in funcOp. - OpBuilder b(funcOp.getContext()); - auto bufferizableOp = - dyn_cast(funcOp.getOperation()); - assert(bufferizableOp && "must use ModuleBufferization"); - if (failed(bufferizableOp.bufferize(b, state))) + // Bufferize the op and its nested ops. + if (failed(bufferize(op, state))) return failure(); // Erase all obsolete ops. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -56,6 +56,10 @@ // Take a guard before anything else. OpBuilder::InsertionGuard g(b); + // Nothing to do. This op is already bufferized. + if (op.hasBufferSemantics()) + return success(); + // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need // basis. if (!op.hasTensorSemantics()) @@ -371,21 +375,21 @@ } // namespace -/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp is replaced +/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp is replaced /// with the the result of `rewriteFunc` if it is anchored on a matching /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def /// chain, starting from the OpOperand and always following the aliasing /// OpOperand, that eventually ends at a single InitTensorOp. 
LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: InitTensorEliminationStep::eliminateInitTensors( - FuncOp funcOp, BufferizationState &state, + Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, std::function anchorMatchFunc, std::function rewriteFunc, SmallVector &newOps) { - OpBuilder b(funcOp->getContext()); + OpBuilder b(op->getContext()); - WalkResult status = funcOp->walk([&](Operation *op) { + WalkResult status = op->walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { // Is this a matching OpOperand? if (!anchorMatchFunc(operand)) @@ -443,7 +447,7 @@ return failure(status.wasInterrupted()); } -/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp can be +/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp can be /// eliminated if it is eventually inserted into another tensor (and some other /// conditions are met). /// @@ -473,10 +477,10 @@ /// out-of-place due to RaW conflicts. LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: InsertSliceAnchoredInitTensorEliminationStep::run( - FuncOp funcOp, BufferizationState &state, + Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { return eliminateInitTensors( - funcOp, state, aliasInfo, + op, state, aliasInfo, [&](OpOperand &operand) { auto insertSliceOp = dyn_cast(operand.getOwner()); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -87,12 +87,13 @@ op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs)); } - LogicalResult run(FuncOp funcOp, BufferizationState &state, + LogicalResult run(Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) override { 
ModuleBufferizationState &moduleState = getModuleBufferizationState(state); // Support only single return-terminated block in the function. + auto funcOp = cast(op); ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); assert(returnOp && "expected func with single return op"); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -264,11 +264,11 @@ }; LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext:: - AssertDestinationPassingStyle::run(FuncOp funcOp, BufferizationState &state, + AssertDestinationPassingStyle::run(Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { LogicalResult status = success(); - funcOp->walk([&](scf::YieldOp yieldOp) { + op->walk([&](scf::YieldOp yieldOp) { auto forOp = dyn_cast(yieldOp->getParentOp()); if (!forOp) return WalkResult::advance(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -432,11 +432,11 @@ } // namespace mlir LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext:: - InplaceInsertSliceOpAnalysis::run(FuncOp funcOp, BufferizationState &state, + InplaceInsertSliceOpAnalysis::run(Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { auto &tensorState = getTensorBufferizationState(state); - funcOp.walk([&](InsertSliceOp insertSliceOp) { + op->walk([&](InsertSliceOp insertSliceOp) { // A copy of the source buffer is needed if either: // - The producer of `source` is not inplace. 
This is the case where a // slice is computed out of place into the inplace full tensor.