diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -47,9 +47,6 @@ BufferizationOptions(); - // BufferizationOptions cannot be copied. - BufferizationOptions(const BufferizationOptions &other) = delete; - /// Return `true` if the op is allowed to be bufferized. bool isOpAllowed(Operation *op) const { if (!hasFilter) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -82,7 +82,7 @@ void populateBufferizationPattern(const BufferizationState &state, RewritePatternSet &patterns); -std::unique_ptr getPartialBufferizationOptions(); +BufferizationOptions getPartialBufferizationOptions(); } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -20,35 +20,25 @@ class BufferizationAliasInfo; struct AnalysisBufferizationOptions; -/// PostAnalysisSteps can be registered with `BufferizationOptions` and are +/// PostAnalysisStepFns can be registered with `BufferizationOptions` and are /// executed after the analysis, but before bufferization. They can be used to -/// implement custom dialect-specific optimizations. -struct PostAnalysisStep { - virtual ~PostAnalysisStep() = default; - - /// Run the post analysis step. This function may modify the IR, but must keep - /// `aliasInfo` consistent. Newly created operations and operations that - /// should be re-analyzed must be added to `newOps`. - virtual LogicalResult run(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) = 0; -}; +/// implement custom dialect-specific optimizations. They may modify the IR, but +/// must keep `aliasInfo` consistent. Newly created operations and operations +/// that should be re-analyzed must be added to `newOps`. +using PostAnalysisStepFn = std::function &)>; -using PostAnalysisStepList = std::vector>; +using PostAnalysisStepList = SmallVector; /// Options for analysis-enabled bufferization. struct AnalysisBufferizationOptions : public BufferizationOptions { AnalysisBufferizationOptions() = default; - // AnalysisBufferizationOptions cannot be copied. - AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete; - /// Register a "post analysis" step. Such steps are executed after the /// analysis, but before bufferization. - template - void addPostAnalysisStep(Args... args) { - postAnalysisSteps.emplace_back( - std::make_unique(std::forward(args)...)); + void addPostAnalysisStep(PostAnalysisStepFn fn) { + postAnalysisSteps.push_back(fn); } /// Registered post analysis steps. diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h @@ -18,42 +18,38 @@ namespace comprehensive_bufferize { namespace linalg_ext { -struct InitTensorEliminationStep : public bufferization::PostAnalysisStep { - /// A function that matches anchor OpOperands for InitTensorOp elimination. - /// If an OpOperand is matched, the function should populate the SmallVector - /// with all values that are needed during `RewriteFn` to produce the - /// replacement value. - using AnchorMatchFn = std::function &)>; - - /// A function that rewrites matched anchors. - using RewriteFn = std::function; - - /// Try to eliminate InitTensorOps inside `op`. - /// - /// * `rewriteFunc` generates the replacement for the InitTensorOp. - /// * Only InitTensorOps that are anchored on a matching OpOperand as per - /// `anchorMatchFunc` are considered. "Anchored" means that there is a path - /// on the reverse SSA use-def chain, starting from the OpOperand and always - /// following the aliasing OpOperand, that eventually ends at a single - /// InitTensorOp. - /// * The result of `rewriteFunc` must usually be analyzed for inplacability. - /// This analysis can be skipped with `skipAnalysis`. - LogicalResult - eliminateInitTensors(Operation *op, bufferization::BufferizationState &state, - bufferization::BufferizationAliasInfo &aliasInfo, - AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc, - SmallVector &newOps); -}; +/// A function that matches anchor OpOperands for InitTensorOp elimination. +/// If an OpOperand is matched, the function should populate the SmallVector +/// with all values that are needed during `RewriteFn` to produce the +/// replacement value. +using AnchorMatchFn = std::function &)>; + +/// A function that rewrites matched anchors. +using RewriteFn = std::function; + +/// Try to eliminate InitTensorOps inside `op`. +/// +/// * `rewriteFunc` generates the replacement for the InitTensorOp. +/// * Only InitTensorOps that are anchored on a matching OpOperand as per +/// `anchorMatchFunc` are considered. "Anchored" means that there is a path +/// on the reverse SSA use-def chain, starting from the OpOperand and always +/// following the aliasing OpOperand, that eventually ends at a single +/// InitTensorOp. +/// * The result of `rewriteFunc` must usually be analyzed for inplacability. +/// This analysis can be skipped with `skipAnalysis`. +LogicalResult +eliminateInitTensors(Operation *op, bufferization::BufferizationState &state, + bufferization::BufferizationAliasInfo &aliasInfo, + AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc, + SmallVector &newOps); /// Try to eliminate InitTensorOps inside `op` that are anchored on an /// InsertSliceOp, i.e., if it is eventually inserted into another tensor /// (and some other conditions are met). -struct InsertSliceAnchoredInitTensorEliminationStep - : public InitTensorEliminationStep { - LogicalResult run(Operation *op, bufferization::BufferizationState &state, - bufferization::BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) override; -}; +LogicalResult insertSliceAnchoredInitTensorEliminationStep( + Operation *op, bufferization::BufferizationState &state, + bufferization::BufferizationAliasInfo &aliasInfo, + SmallVector &newOps); void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); diff --git a/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h --- a/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h @@ -14,16 +14,21 @@ namespace mlir { class DialectRegistry; +namespace bufferization { +class BufferizationState; +class BufferizationAliasInfo; +} // namespace bufferization + namespace scf { /// Assert that yielded values of an scf.for op are aliasing their corresponding /// bbArgs. This is required because the i-th OpResult of an scf.for op is /// currently assumed to alias with the i-th iter_arg (in the absence of /// conflicts). -struct AssertScfForAliasingProperties : public bufferization::PostAnalysisStep { - LogicalResult run(Operation *op, bufferization::BufferizationState &state, - bufferization::BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) override; -}; +LogicalResult +assertScfForAliasingProperties(Operation *op, + bufferization::BufferizationState &state, + bufferization::BufferizationAliasInfo &aliasInfo, + SmallVector &newOps); void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); } // namespace scf diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp @@ -29,16 +29,15 @@ } void runOnOperation() override { - std::unique_ptr options = - getPartialBufferizationOptions(); + BufferizationOptions options = getPartialBufferizationOptions(); if (constantOpOnly) { - options->addToOperationFilter(); + options.addToOperationFilter(); } else { - options->addToDialectFilter(); + options.addToDialectFilter(); } - options->bufferAlignment = alignment; + options.bufferAlignment = alignment; - if (failed(bufferizeOp(getOperation(), *options))) + if (failed(bufferizeOp(getOperation(), options))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -253,12 +253,11 @@ patterns.add(patterns.getContext(), state); } -std::unique_ptr -bufferization::getPartialBufferizationOptions() { - auto options = std::make_unique(); - options->allowReturnMemref = true; - options->allowUnknownOps = true; - options->createDeallocs = false; - options->fullyDynamicLayoutMaps = false; +BufferizationOptions bufferization::getPartialBufferizationOptions() { + BufferizationOptions options; + options.allowReturnMemref = true; + options.allowUnknownOps = true; + options.createDeallocs = false; + options.fullyDynamicLayoutMaps = false; return options; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -698,52 +698,51 @@ // aliasing values, which is stricter than needed. We can currently not check // for aliasing values because the analysis is a maybe-alias analysis and we // need a must-alias analysis here. -struct AssertDestinationPassingStyle : public PostAnalysisStep { - LogicalResult run(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) override { - LogicalResult status = success(); - DominanceInfo domInfo(op); - op->walk([&](Operation *returnOp) { - if (!isRegionReturnLike(returnOp)) - return WalkResult::advance(); - - for (OpOperand &returnValOperand : returnOp->getOpOperands()) { - Value returnVal = returnValOperand.get(); - // Skip non-tensor values. - if (!returnVal.getType().isa()) - continue; +static LogicalResult +assertDestinationPassingStyle(Operation *op, BufferizationState &state, + BufferizationAliasInfo &aliasInfo, + SmallVector &newOps) { + LogicalResult status = success(); + DominanceInfo domInfo(op); + op->walk([&](Operation *returnOp) { + if (!isRegionReturnLike(returnOp)) + return WalkResult::advance(); - bool foundEquivValue = false; - aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) { - if (auto bbArg = equivVal.dyn_cast()) { - Operation *definingOp = bbArg.getOwner()->getParentOp(); - if (definingOp->isProperAncestor(returnOp)) - foundEquivValue = true; - return; - } + for (OpOperand &returnValOperand : returnOp->getOpOperands()) { + Value returnVal = returnValOperand.get(); + // Skip non-tensor values. + if (!returnVal.getType().isa()) + continue; - Operation *definingOp = equivVal.getDefiningOp(); - if (definingOp->getBlock()->findAncestorOpInBlock( - *returnOp->getParentOp())) - // Skip ops that happen after `returnOp` and parent ops. - if (happensBefore(definingOp, returnOp, domInfo)) - foundEquivValue = true; - }); - - if (!foundEquivValue) - status = - returnOp->emitError() - << "operand #" << returnValOperand.getOperandNumber() - << " of ReturnLike op does not satisfy destination passing style"; - } + bool foundEquivValue = false; + aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) { + if (auto bbArg = equivVal.dyn_cast()) { + Operation *definingOp = bbArg.getOwner()->getParentOp(); + if (definingOp->isProperAncestor(returnOp)) + foundEquivValue = true; + return; + } - return WalkResult::advance(); - }); + Operation *definingOp = equivVal.getDefiningOp(); + if (definingOp->getBlock()->findAncestorOpInBlock( + *returnOp->getParentOp())) + // Skip ops that happen after `returnOp` and parent ops. + if (happensBefore(definingOp, returnOp, domInfo)) + foundEquivValue = true; + }); + + if (!foundEquivValue) + status = + returnOp->emitError() + << "operand #" << returnValOperand.getOperandNumber() + << " of ReturnLike op does not satisfy destination passing style"; + } - return status; - } -}; + return WalkResult::advance(); + }); + + return status; +} LogicalResult bufferization::analyzeOp(Operation *op, AnalysisBufferizationState &state) { @@ -761,12 +760,11 @@ return failure(); equivalenceAnalysis(op, aliasInfo, state); - for (const std::unique_ptr &step : - options.postAnalysisSteps) { + for (const PostAnalysisStepFn &fn : options.postAnalysisSteps) { SmallVector newOps; - if (failed(step->run(op, state, aliasInfo, newOps))) + if (failed(fn(op, state, aliasInfo, newOps))) return failure(); - // Analyze ops that were created by the PostAnalysisStep. + // Analyze ops that were created by the PostAnalysisStepFn. if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo))) return failure(); equivalenceAnalysis(newOps, aliasInfo, state); @@ -774,8 +772,7 @@ if (!options.allowReturnMemref) { SmallVector newOps; - if (failed( - AssertDestinationPassingStyle().run(op, state, aliasInfo, newOps))) + if (failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps))) return failure(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -524,11 +524,10 @@ /// chain, starting from the OpOperand and always following the aliasing /// OpOperand, that eventually ends at a single InitTensorOp. LogicalResult -mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep:: - eliminateInitTensors(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, - AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc, - SmallVector &newOps) { +mlir::linalg::comprehensive_bufferize::linalg_ext::eliminateInitTensors( + Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, + AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc, + SmallVector &newOps) { OpBuilder b(op->getContext()); WalkResult status = op->walk([&](Operation *op) { @@ -628,7 +627,7 @@ /// Note that the newly inserted ExtractSliceOp may have to bufferize /// out-of-place due to RaW conflicts. LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: - InsertSliceAnchoredInitTensorEliminationStep::run( + insertSliceAnchoredInitTensorEliminationStep( Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { return eliminateInitTensors( diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -16,11 +16,12 @@ // their respective callers. // // After analyzing a FuncOp, additional information about its bbArgs is -// gathered through PostAnalysisSteps and stored in `ModuleBufferizationState`. +// gathered through PostAnalysisStepFns and stored in +// `ModuleBufferizationState`. // -// * `EquivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each +// * `equivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each // tensor return value (if any). -// * `FuncOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is +// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is // read/written. // // Only tensors that are equivalent to some FuncOp bbArg may be returned. @@ -47,7 +48,7 @@ // modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize // out-of-place because `%t0` is modified by the callee but read by the // tensor.extract op. The analysis of CallOps decides whether an OpOperand must -// bufferize out-of-place based on results of `FuncOpBbArgReadWriteAnalysis`. +// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`. // ``` // func @callee(%t1 : tensor) -> tensor { // %f = ... : f32 @@ -62,7 +63,7 @@ // } // ``` // -// Note: If a function is external, `FuncOpBbArgReadWriteAnalysis` cannot +// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot // analyze the function body. In such a case, the CallOp analysis conservatively // assumes that each tensor OpOperand is both read and written. // @@ -159,55 +160,55 @@ } namespace { -/// Store function BlockArguments that are equivalent to a returned value in -/// ModuleBufferizationState. -struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep { - /// Annotate IR with the results of the analysis. For testing purposes only. - static void annotateReturnOp(OpOperand &returnVal, BlockArgument bbArg) { - const char *kEquivalentArgsAttr = "__equivalent_func_args__"; - Operation *op = returnVal.getOwner(); - - SmallVector equivBbArgs; - if (op->hasAttr(kEquivalentArgsAttr)) { - auto attr = op->getAttr(kEquivalentArgsAttr).cast(); - equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { - return a.cast().getValue().getSExtValue(); - })); - } else { - equivBbArgs.append(op->getNumOperands(), -1); - } - equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber(); - OpBuilder b(op->getContext()); - op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs)); +/// Annotate IR with the results of the analysis. For testing purposes only. +static void annotateEquivalentReturnBbArg(OpOperand &returnVal, + BlockArgument bbArg) { + const char *kEquivalentArgsAttr = "__equivalent_func_args__"; + Operation *op = returnVal.getOwner(); + + SmallVector equivBbArgs; + if (op->hasAttr(kEquivalentArgsAttr)) { + auto attr = op->getAttr(kEquivalentArgsAttr).cast(); + equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { + return a.cast().getValue().getSExtValue(); + })); + } else { + equivBbArgs.append(op->getNumOperands(), -1); } + equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber(); - LogicalResult run(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) override { - ModuleBufferizationState &moduleState = getModuleBufferizationState(state); + OpBuilder b(op->getContext()); + op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs)); +} - // Support only single return-terminated block in the function. - auto funcOp = cast(op); - ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - assert(returnOp && "expected func with single return op"); - - for (OpOperand &returnVal : returnOp->getOpOperands()) - if (returnVal.get().getType().isa()) - for (BlockArgument bbArg : funcOp.getArguments()) - if (bbArg.getType().isa()) - if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), - bbArg)) { - moduleState - .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] = - bbArg.getArgNumber(); - if (state.getOptions().testAnalysisOnly) - annotateReturnOp(returnVal, bbArg); - } +/// Store function BlockArguments that are equivalent to a returned value in +/// ModuleBufferizationState. +static LogicalResult +equivalentFuncOpBBArgsAnalysis(Operation *op, BufferizationState &state, + BufferizationAliasInfo &aliasInfo, + SmallVector &newOps) { + ModuleBufferizationState &moduleState = getModuleBufferizationState(state); - return success(); - } -}; + // Support only single return-terminated block in the function. + auto funcOp = cast(op); + ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + assert(returnOp && "expected func with single return op"); + + for (OpOperand &returnVal : returnOp->getOpOperands()) + if (returnVal.get().getType().isa()) + for (BlockArgument bbArg : funcOp.getArguments()) + if (bbArg.getType().isa()) + if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) { + moduleState + .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] = + bbArg.getArgNumber(); + if (state.getOptions().testAnalysisOnly) + annotateEquivalentReturnBbArg(returnVal, bbArg); + } + + return success(); +} /// Return true if the buffer of the given tensor value is written to. Must not /// be called for values inside not yet analyzed functions. (Post-analysis @@ -239,38 +240,37 @@ } /// Determine which FuncOp bbArgs are read and which are written. If this -/// PostAnalysisStep is run on a function with unknown ops, it will +/// PostAnalysisStepFn is run on a function with unknown ops, it will /// conservatively assume that such ops bufferize to a read + write. -struct FuncOpBbArgReadWriteAnalysis : public PostAnalysisStep { - LogicalResult run(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) override { - ModuleBufferizationState &moduleState = getModuleBufferizationState(state); - auto funcOp = cast(op); - - // If the function has no body, conservatively assume that all args are - // read + written. - if (funcOp.getBody().empty()) { - for (BlockArgument bbArg : funcOp.getArguments()) { - moduleState.readBbArgs.insert(bbArg); - moduleState.writtenBbArgs.insert(bbArg); - } - - return success(); - } +static LogicalResult +funcOpBbArgReadWriteAnalysis(Operation *op, BufferizationState &state, + BufferizationAliasInfo &aliasInfo, + SmallVector &newOps) { + ModuleBufferizationState &moduleState = getModuleBufferizationState(state); + auto funcOp = cast(op); + // If the function has no body, conservatively assume that all args are + // read + written. + if (funcOp.getBody().empty()) { for (BlockArgument bbArg : funcOp.getArguments()) { - if (!bbArg.getType().isa()) - continue; - if (state.isValueRead(bbArg)) - moduleState.readBbArgs.insert(bbArg); - if (isValueWritten(bbArg, state, aliasInfo)) - moduleState.writtenBbArgs.insert(bbArg); + moduleState.readBbArgs.insert(bbArg); + moduleState.writtenBbArgs.insert(bbArg); } return success(); } -}; + + for (BlockArgument bbArg : funcOp.getArguments()) { + if (!bbArg.getType().isa()) + continue; + if (state.isValueRead(bbArg)) + moduleState.readBbArgs.insert(bbArg); + if (isValueWritten(bbArg, state, aliasInfo)) + moduleState.writtenBbArgs.insert(bbArg); + } + + return success(); +} } // namespace static bool isaTensor(Type t) { return t.isa(); } @@ -983,10 +983,8 @@ return failure(); // Collect bbArg/return value information after the analysis. - options->postAnalysisSteps.emplace_back( - std::make_unique()); - options->postAnalysisSteps.emplace_back( - std::make_unique()); + options->postAnalysisSteps.push_back(equivalentFuncOpBBArgsAnalysis); + options->postAnalysisSteps.push_back(funcOpBbArgReadWriteAnalysis); // Analyze ops. for (FuncOp funcOp : moduleState.orderedFuncOps) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -125,12 +125,12 @@ // Enable InitTensorOp elimination. if (initTensorElimination) { - options->addPostAnalysisStep< - linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>(); + options->addPostAnalysisStep( + linalg_ext::insertSliceAnchoredInitTensorEliminationStep); } // Only certain scf.for ops are supported by the analysis. - options->addPostAnalysisStep(); + options->addPostAnalysisStep(scf::assertScfForAliasingProperties); ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp); diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -432,7 +432,7 @@ } // namespace scf } // namespace mlir -LogicalResult mlir::scf::AssertScfForAliasingProperties::run( +LogicalResult mlir::scf::assertScfForAliasingProperties( Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { LogicalResult status = success(); diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -30,11 +30,10 @@ namespace { struct TensorBufferizePass : public TensorBufferizeBase { void runOnOperation() override { - std::unique_ptr options = - getPartialBufferizationOptions(); - options->addToDialectFilter(); + BufferizationOptions options = getPartialBufferizationOptions(); + options.addToDialectFilter(); - if (failed(bufferizeOp(getOperation(), *options))) + if (failed(bufferizeOp(getOperation(), options))) signalPassFailure(); } diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -104,7 +104,7 @@ auto options = std::make_unique(); if (!allowReturnMemref) - options->addPostAnalysisStep(); + options->addPostAnalysisStep(scf::assertScfForAliasingProperties); options->allowReturnMemref = allowReturnMemref; options->allowUnknownOps = allowUnknownOps;