diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h @@ -27,9 +27,9 @@ /// Run Module Bufferization on the given module. Performs a simple function /// call analysis to determine which function arguments are inplaceable. Then /// analyzes and bufferizes FuncOps one-by-one with One-Shot Bufferize. -LogicalResult runComprehensiveBufferize( - ModuleOp moduleOp, - std::unique_ptr options); +LogicalResult +runModuleBufferize(ModuleOp moduleOp, + bufferization::AnalysisBufferizationOptions options); namespace std_ext { diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -18,6 +18,9 @@ #include "mlir/Pass/Pass.h" namespace mlir { +namespace bufferization { +struct AnalysisBufferizationOptions; +} // namespace bufferization std::unique_ptr createConvertElementwiseToLinalgPass(); @@ -64,8 +67,8 @@ /// on SSA use-def chains starting from function operands that are annotated /// with the 'inplaceable' attribute. std::unique_ptr createLinalgComprehensiveModuleBufferizePass(); -std::unique_ptr -createLinalgComprehensiveModuleBufferizePass(bool useLinalgCopy); +std::unique_ptr createLinalgComprehensiveModuleBufferizePass( + const bufferization::AnalysisBufferizationOptions &options); /// Create a pass to convert Linalg operations which work on tensors to use /// buffers instead. diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -52,9 +52,6 @@ Option<"useAlloca", "use-alloca", "bool", /*default=*/"false", "Use stack allocations for memrefs (for testing purposes only)">, - Option<"useLinalgCopy", "use-memref.copy", "bool", - /*default=*/"false", - "Use a copy operation implemented as a Linalg op.">, Option<"fullyDynamicLayoutMaps", "fully-dynamic-layout-maps", "bool", /*default=*/"true", "Generate MemRef types with dynamic offset+strides by default.">, diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -135,6 +135,11 @@ Value outputTensor, ArrayRef transposeVector); +/// Returns GenericOp that copies an n-D memref. Unlike the current +/// implementation of memref::CopyOp, this op can further tile, lower to loops +/// or vectorize. +GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to); + //===----------------------------------------------------------------------===// // Fusion / Tiling utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -10,10 +10,10 @@ // bufferizes function boundaries. It provides `BufferizableOpInterface` // implementations for FuncOp, CallOp and ReturnOp. // -// Module Bufferization is run via `runComprehensiveBufferize(ModuleOp, ...)`. -// This function analyzed the given module and determines the order of -// analysis and bufferization: Functions that are called are processed before -// their respective callers. +// Module Bufferization is run via `runModuleBufferize(ModuleOp, ...)`. This +// function analyzes the given module and determines the order of analysis and +// bufferization: Functions that are called are processed before their +// respective callers. // // After analyzing a FuncOp, additional information about its bbArgs is // gathered through PostAnalysisStepFns and stored in @@ -971,10 +971,10 @@ setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state)); } -LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( - ModuleOp moduleOp, std::unique_ptr options) { +LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize( + ModuleOp moduleOp, AnalysisBufferizationOptions options) { IRRewriter rewriter(moduleOp.getContext()); - AnalysisBufferizationState state(moduleOp, *options); + AnalysisBufferizationState state(moduleOp, options); ModuleBufferizationState &moduleState = getModuleBufferizationState(state); BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); @@ -983,8 +983,8 @@ return failure(); // Collect bbArg/return value information after the analysis. - options->postAnalysisSteps.push_back(equivalentFuncOpBBArgsAnalysis); - options->postAnalysisSteps.push_back(funcOpBbArgReadWriteAnalysis); + options.addPostAnalysisStep(equivalentFuncOpBBArgsAnalysis); + options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis); // Analyze ops. for (FuncOp funcOp : moduleState.orderedFuncOps) { @@ -1007,11 +1007,11 @@ moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed; // Add annotations to function arguments. - if (options->testAnalysisOnly) + if (options.testAnalysisOnly) annotateOpsWithBufferizationMarkers(funcOp, state); } - if (options->testAnalysisOnly) + if (options.testAnalysisOnly) return success(); // Bufferize function bodies. @@ -1031,7 +1031,7 @@ if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, state))) return failure(); - if (!options->allowReturnMemref && + if (!options.allowReturnMemref && llvm::any_of(funcOp.getType().getResults(), [](Type t) { return t.isa(); })) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -38,9 +38,9 @@ LinalgComprehensiveModuleBufferize( const LinalgComprehensiveModuleBufferize &p) = default; - LinalgComprehensiveModuleBufferize(bool linalgCopy) { - this->useLinalgCopy = linalgCopy; - } + explicit LinalgComprehensiveModuleBufferize( + AnalysisBufferizationOptions options) + : options(options) {} void runOnOperation() override; @@ -58,6 +58,9 @@ tensor::registerBufferizableOpInterfaceExternalModels(registry); vector::registerBufferizableOpInterfaceExternalModels(registry); } + +private: + llvm::Optional options; }; } // namespace @@ -76,71 +79,44 @@ return allocated; } -/// Create a linalg::GenericOp version of an n-D copy that can further tile, -/// lower to loops or vectorize, unlike the current implementation of -/// memref::CopyOp. -/// Do not depend on memref::CopyOp that is getting deprecated. -static LogicalResult createLinalgCopyOp(OpBuilder &b, Location loc, Value from, - Value to) { - auto memrefTypeFrom = from.getType().cast(); - auto memrefTypeTo = to.getType().cast(); - if (!memrefTypeFrom || !memrefTypeTo || - memrefTypeFrom.getRank() != memrefTypeTo.getRank()) - return failure(); - AffineMap id = - AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext()); - SmallVector iteratorTypes(memrefTypeTo.getRank(), - getParallelIteratorTypeName()); - b.create(loc, - /*inputs=*/from, - /*outputs=*/to, - /*indexingMaps=*/llvm::makeArrayRef({id, id}), - /*iteratorTypes=*/iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args.front()); - }); - return success(); -} - void LinalgComprehensiveModuleBufferize::runOnOperation() { - auto options = std::make_unique(); - if (useAlloca) { - options->allocationFn = allocationFnUsingAlloca; - options->deallocationFn = [](OpBuilder &b, Location loc, Value v) { - return success(); - }; - } - // TODO: atm memref::CopyOp can be 200x slower than linalg::GenericOp. - // Once this perf bug is fixed more systematically, we can revisit. - if (useLinalgCopy) - options->memCpyFn = createLinalgCopyOp; - - options->allowReturnMemref = allowReturnMemref; - options->allowUnknownOps = allowUnknownOps; - options->analysisFuzzerSeed = analysisFuzzerSeed; - options->createDeallocs = createDeallocs; - options->fullyDynamicLayoutMaps = fullyDynamicLayoutMaps; - options->printConflicts = printConflicts; - options->testAnalysisOnly = testAnalysisOnly; - - // Enable InitTensorOp elimination. - if (initTensorElimination) { - options->addPostAnalysisStep( - linalg_ext::insertSliceAnchoredInitTensorEliminationStep); + AnalysisBufferizationOptions opt; + if (!options) { + // Make new bufferization options if none were provided when creating the + // pass. + if (useAlloca) { + opt.allocationFn = allocationFnUsingAlloca; + opt.deallocationFn = [](OpBuilder &b, Location loc, Value v) { + return success(); + }; + } + opt.allowReturnMemref = allowReturnMemref; + opt.allowUnknownOps = allowUnknownOps; + opt.analysisFuzzerSeed = analysisFuzzerSeed; + opt.createDeallocs = createDeallocs; + opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps; + opt.printConflicts = printConflicts; + opt.testAnalysisOnly = testAnalysisOnly; + if (initTensorElimination) { + opt.addPostAnalysisStep( + linalg_ext::insertSliceAnchoredInitTensorEliminationStep); + } + } else { + opt = *options; } // Only certain scf.for ops are supported by the analysis. - options->addPostAnalysisStep(scf::assertScfForAliasingProperties); + opt.addPostAnalysisStep(scf::assertScfForAliasingProperties); ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp); - if (failed(runComprehensiveBufferize(moduleOp, std::move(options)))) { + if (failed(runModuleBufferize(moduleOp, opt))) { signalPassFailure(); return; } - if (testAnalysisOnly) + if (opt.testAnalysisOnly) return; OpPassManager cleanupPipeline("builtin.module"); @@ -154,7 +130,7 @@ return std::make_unique(); } -std::unique_ptr -mlir::createLinalgComprehensiveModuleBufferizePass(bool useLinalgCopy) { - return std::make_unique(useLinalgCopy); +std::unique_ptr mlir::createLinalgComprehensiveModuleBufferizePass( + const AnalysisBufferizationOptions &options) { + return std::make_unique(options); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -423,6 +423,29 @@ return transposeOp; } +GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) { + auto memrefTypeTo = to.getType().cast(); +#ifndef NDEBUG + auto memrefTypeFrom = from.getType().cast(); + assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() && + "`from` and `to` memref must have the same rank"); +#endif // NDEBUG + + AffineMap id = + AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext()); + SmallVector iteratorTypes(memrefTypeTo.getRank(), + getParallelIteratorTypeName()); + return b.create( + loc, + /*inputs=*/from, + /*outputs=*/to, + /*indexingMaps=*/llvm::makeArrayRef({id, id}), + /*iteratorTypes=*/iteratorTypes, + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args.front()); + }); +} + /// Specialization to build an scf "for" nest. template <> void GenerateLoopNest::doit(