diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
@@ -27,10 +27,24 @@
 /// Run Module Bufferization on the given module. Performs a simple function
 /// call analysis to determine which function arguments are inplaceable. Then
 /// analyzes and bufferizes FuncOps one-by-one with One-Shot Bufferize.
-LogicalResult runComprehensiveBufferize(
+LogicalResult runModuleBufferize(
     ModuleOp moduleOp,
     std::unique_ptr<bufferization::AnalysisBufferizationOptions> options);
 
+/// Run Module Bufferization on the given module. Performs a simple function
+/// call analysis to determine which function arguments are inplaceable. Then
+/// analyzes and bufferizes FuncOps one-by-one with One-Shot Bufferize.
+///
+/// Note: The bufferization options that are passed to this function overload
+/// must have been prepared with `prepareOptions`.
+LogicalResult runModuleBufferize(
+    ModuleOp moduleOp,
+    const bufferization::AnalysisBufferizationOptions &preparedOptions);
+
+/// Prepare bufferization options. This enqueues all PostAnalysisSteps that are
+/// required for Module Bufferization.
+void prepareOptions(bufferization::AnalysisBufferizationOptions &options);
+
 namespace std_ext {
 
 void registerModuleBufferizationExternalModels(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -18,6 +18,9 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+namespace bufferization {
+struct AnalysisBufferizationOptions;
+} // namespace bufferization
 
 std::unique_ptr<Pass> createConvertElementwiseToLinalgPass();
 
@@ -64,8 +67,8 @@
 /// on SSA use-def chains starting from function operands that are annotated
 /// with the 'inplaceable' attribute.
 std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
-std::unique_ptr<Pass>
-createLinalgComprehensiveModuleBufferizePass(bool useLinalgCopy);
+std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass(
+    std::unique_ptr<bufferization::AnalysisBufferizationOptions> options);
 
 /// Create a pass to convert Linalg operations which work on tensors to use
 /// buffers instead.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -52,9 +52,6 @@
     Option<"useAlloca", "use-alloca", "bool",
            /*default=*/"false",
            "Use stack allocations for memrefs (for testing purposes only)">,
-    Option<"useLinalgCopy", "use-memref.copy", "bool",
-           /*default=*/"false",
-           "Use a copy operation implemented as a Linalg op.">,
    Option<"fullyDynamicLayoutMaps", "fully-dynamic-layout-maps", "bool",
           /*default=*/"true",
           "Generate MemRef types with dynamic offset+strides by default.">,
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -135,6 +135,11 @@
                           Value outputTensor,
                           ArrayRef<int64_t> transposeVector);
 
+/// Returns a GenericOp that copies an n-D memref. Unlike the current
+/// implementation of memref::CopyOp, this op can further tile, lower to loops
+/// or vectorize.
+GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -10,10 +10,10 @@
 // bufferizes function boundaries. It provides `BufferizableOpInterface`
 // implementations for FuncOp, CallOp and ReturnOp.
 //
-// Module Bufferization is run via `runComprehensiveBufferize(ModuleOp, ...)`.
-// This function analyzed the given module and determines the order of
-// analysis and bufferization: Functions that are called are processed before
-// their respective callers.
+// Module Bufferization is run via `runModuleBufferize(ModuleOp, ...)`. This
+// function analyzes the given module and determines the order of analysis and
+// bufferization: Functions that are called are processed before their
+// respective callers.
 //
 // After analyzing a FuncOp, additional information about its bbArgs is
 // gathered through PostAnalysisSteps and stored in `ModuleBufferizationState`.
@@ -971,10 +971,19 @@
     setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state));
 }
 
-LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
-    ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options) {
+void mlir::linalg::comprehensive_bufferize::prepareOptions(
+    AnalysisBufferizationOptions &options) {
+  // Collect bbArg/return value information after the analysis.
+  options.postAnalysisSteps.emplace_back(
+      std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
+  options.postAnalysisSteps.emplace_back(
+      std::make_unique<FuncOpBbArgReadWriteAnalysis>());
+}
+
+LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
+    ModuleOp moduleOp, const AnalysisBufferizationOptions &preparedOptions) {
   IRRewriter rewriter(moduleOp.getContext());
-  AnalysisBufferizationState state(moduleOp, *options);
+  AnalysisBufferizationState state(moduleOp, preparedOptions);
   ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
   BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
 
@@ -982,12 +991,6 @@
                                          moduleState.callerMap)))
     return failure();
 
-  // Collect bbArg/return value information after the analysis.
-  options->postAnalysisSteps.emplace_back(
-      std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
-  options->postAnalysisSteps.emplace_back(
-      std::make_unique<FuncOpBbArgReadWriteAnalysis>());
-
   // Analyze ops.
   for (FuncOp funcOp : moduleState.orderedFuncOps) {
     // No body => no analysis.
@@ -1009,11 +1012,11 @@
     moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
 
     // Add annotations to function arguments.
-    if (options->testAnalysisOnly)
+    if (preparedOptions.testAnalysisOnly)
      annotateOpsWithBufferizationMarkers(funcOp, state);
   }
 
-  if (options->testAnalysisOnly)
+  if (preparedOptions.testAnalysisOnly)
    return success();
 
  // Bufferize function bodies.
@@ -1033,7 +1036,7 @@
     if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, state)))
       return failure();
 
-    if (!options->allowReturnMemref &&
+    if (!preparedOptions.allowReturnMemref &&
        llvm::any_of(funcOp.getType().getResults(), [](Type t) {
          return t.isa<MemRefType>();
        })) {
@@ -1054,3 +1057,9 @@
 
   return success();
 }
+
+LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
+    ModuleOp moduleOp,
+    std::unique_ptr<AnalysisBufferizationOptions> options) {
+  prepareOptions(*options);
+  return runModuleBufferize(moduleOp, *options);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -37,12 +37,13 @@
   LinalgComprehensiveModuleBufferize() = default;
 
   LinalgComprehensiveModuleBufferize(
-      const LinalgComprehensiveModuleBufferize &p) = default;
-
-  LinalgComprehensiveModuleBufferize(bool linalgCopy) {
-    this->useLinalgCopy = linalgCopy;
+      const LinalgComprehensiveModuleBufferize &p) {
+    llvm_unreachable("pass cannot be copied");
   }
 
+  explicit LinalgComprehensiveModuleBufferize(
+      std::unique_ptr<AnalysisBufferizationOptions> options);
+
   void runOnOperation() override;
 
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -60,9 +61,20 @@
     tensor::registerBufferizableOpInterfaceExternalModels(registry);
     vector::registerBufferizableOpInterfaceExternalModels(registry);
   }
+
+private:
+  std::unique_ptr<AnalysisBufferizationOptions> options;
 };
 } // namespace
 
+LinalgComprehensiveModuleBufferize::LinalgComprehensiveModuleBufferize(
+    std::unique_ptr<AnalysisBufferizationOptions> options)
+    : options(std::move(options)) {
+  // Only certain scf.for ops are supported by the analysis.
+  this->options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
+  prepareOptions(*this->options);
+}
+
 static void applyEnablingTransformations(ModuleOp moduleOp) {
   RewritePatternSet patterns(moduleOp.getContext());
   patterns.add<GeneralizePadTensorOpPattern>(moduleOp.getContext());
@@ -78,71 +90,42 @@
   return allocated;
 }
 
-/// Create a linalg::GenericOp version of an n-D copy that can further tile,
-/// lower to loops or vectorize, unlike the current implementation of
-/// memref::CopyOp.
-/// Do not depend on memref::CopyOp that is getting deprecated.
-static LogicalResult createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
-                                        Value to) {
-  auto memrefTypeFrom = from.getType().cast<MemRefType>();
-  auto memrefTypeTo = to.getType().cast<MemRefType>();
-  if (!memrefTypeFrom || !memrefTypeTo ||
-      memrefTypeFrom.getRank() != memrefTypeTo.getRank())
-    return failure();
-  AffineMap id =
-      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
-  SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
-                                       getParallelIteratorTypeName());
-  b.create<linalg::GenericOp>(loc,
-                              /*inputs=*/from,
-                              /*outputs=*/to,
-                              /*indexingMaps=*/llvm::makeArrayRef({id, id}),
-                              /*iteratorTypes=*/iteratorTypes,
-                              [](OpBuilder &b, Location loc, ValueRange args) {
-                                b.create<linalg::YieldOp>(loc, args.front());
-                              });
-  return success();
-}
-
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
-  auto options = std::make_unique<AnalysisBufferizationOptions>();
-  if (useAlloca) {
-    options->allocationFn = allocationFnUsingAlloca;
-    options->deallocationFn = [](OpBuilder &b, Location loc, Value v) {
-      return success();
-    };
-  }
-  // TODO: atm memref::CopyOp can be 200x slower than linalg::GenericOp.
-  // Once this perf bug is fixed more systematically, we can revisit.
-  if (useLinalgCopy)
-    options->memCpyFn = createLinalgCopyOp;
-
-  options->allowReturnMemref = allowReturnMemref;
-  options->allowUnknownOps = allowUnknownOps;
-  options->analysisFuzzerSeed = analysisFuzzerSeed;
-  options->createDeallocs = createDeallocs;
-  options->fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
-  options->printConflicts = printConflicts;
-  options->testAnalysisOnly = testAnalysisOnly;
-
-  // Enable InitTensorOp elimination.
-  if (initTensorElimination) {
-    options->addPostAnalysisStep<
-        linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
+  if (!options) {
+    // Make new bufferization options if none were provided when creating the
+    // pass.
+    options = std::make_unique<AnalysisBufferizationOptions>();
+    if (useAlloca) {
+      options->allocationFn = allocationFnUsingAlloca;
+      options->deallocationFn = [](OpBuilder &b, Location loc, Value v) {
+        return success();
+      };
+    }
+    options->allowReturnMemref = allowReturnMemref;
+    options->allowUnknownOps = allowUnknownOps;
+    options->analysisFuzzerSeed = analysisFuzzerSeed;
+    options->createDeallocs = createDeallocs;
+    options->fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
+    options->printConflicts = printConflicts;
+    options->testAnalysisOnly = testAnalysisOnly;
+    if (initTensorElimination) {
+      options->addPostAnalysisStep<
+          linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
+    }
+    // Only certain scf.for ops are supported by the analysis.
+    options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
+    prepareOptions(*options);
   }
 
-  // Only certain scf.for ops are supported by the analysis.
-  options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
-
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
 
-  if (failed(runComprehensiveBufferize(moduleOp, std::move(options)))) {
+  if (failed(runModuleBufferize(moduleOp, *options))) {
     signalPassFailure();
     return;
   }
 
-  if (testAnalysisOnly)
+  if (options->testAnalysisOnly)
     return;
 
   OpPassManager cleanupPipeline("builtin.module");
@@ -156,7 +139,8 @@
   return std::make_unique<LinalgComprehensiveModuleBufferize>();
 }
 
-std::unique_ptr<Pass>
-mlir::createLinalgComprehensiveModuleBufferizePass(bool useLinalgCopy) {
-  return std::make_unique<LinalgComprehensiveModuleBufferize>(useLinalgCopy);
+std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass(
+    std::unique_ptr<AnalysisBufferizationOptions> options) {
+  return std::make_unique<LinalgComprehensiveModuleBufferize>(
+      std::move(options));
 }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -423,6 +423,29 @@
   return transposeOp;
 }
 
+GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
+  auto memrefTypeTo = to.getType().cast<MemRefType>();
+#ifndef NDEBUG
+  auto memrefTypeFrom = from.getType().cast<MemRefType>();
+  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
+         "`from` and `to` memref must have the same rank");
+#endif // NDEBUG
+
+  AffineMap id =
+      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
+  SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
+                                       getParallelIteratorTypeName());
+  return b.create<linalg::GenericOp>(
+      loc,
+      /*inputs=*/from,
+      /*outputs=*/to,
+      /*indexingMaps=*/llvm::makeArrayRef({id, id}),
+      /*iteratorTypes=*/iteratorTypes,
+      [](OpBuilder &b, Location loc, ValueRange args) {
+        b.create<linalg::YieldOp>(loc, args.front());
+      });
+}
+
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(
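
Usage note (not part of the patch): with the `use-memref.copy` pass option removed, a downstream pipeline selects the linalg-based copy by building its own options and handing them to the new pass-constructor overload. The sketch below shows one way to do that. The function name `addModuleBufferization`, the include set, and the header that defines `AnalysisBufferizationOptions` are assumptions; the `memCpyFn` member and its signature are inferred from the removed `createLinalgCopyOp` hook above and may differ in other MLIR revisions.

    // Sketch only. Assumes AnalysisBufferizationOptions exposes a memCpyFn
    // member with the signature LogicalResult(OpBuilder &, Location, Value,
    // Value), matching the removed createLinalgCopyOp hook.
    #include <memory>
    #include <utility>

    #include "mlir/Dialect/Linalg/Passes.h"
    #include "mlir/Dialect/Linalg/Utils/Utils.h"
    #include "mlir/Pass/PassManager.h"
    // Also include the Bufferization header that defines
    // AnalysisBufferizationOptions (path depends on the MLIR revision).

    void addModuleBufferization(mlir::OpPassManager &pm) {
      auto options =
          std::make_unique<mlir::bufferization::AnalysisBufferizationOptions>();
      // Emit a tileable/vectorizable linalg.generic instead of memref.copy.
      options->memCpyFn = [](mlir::OpBuilder &b, mlir::Location loc,
                             mlir::Value from, mlir::Value to) {
        mlir::linalg::makeMemRefCopyOp(b, loc, from, to);
        return mlir::success();
      };
      // The pass constructor calls prepareOptions() and registers the required
      // post-analysis steps, so no further setup is needed here.
      pm.addPass(
          mlir::createLinalgComprehensiveModuleBufferizePass(std::move(options)));
    }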