diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -64,6 +64,8 @@
 /// on SSA use-def chains starting from function operands that are annotated
 /// with the 'inplaceable' attribute.
 std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
+std::unique_ptr<Pass>
+createLinalgComprehensiveModuleBufferizePass(bool useLinalgCopy);
 
 /// Create a pass to convert Linalg operations which work on tensors to use
 /// buffers instead.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -52,6 +52,9 @@
     Option<"useAlloca", "use-alloca", "bool",
            /*default=*/"false",
            "Use stack allocations for memrefs (for testing purposes only)">,
+    Option<"useLinalgCopy", "use-linalg-copy", "bool",
+           /*default=*/"false",
+           "Use a copy operation implemented as a Linalg op.">,
     Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned",
            /*default=*/"0",
            "Analyze ops in random order with a given seed (fuzzer)">,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -39,6 +39,10 @@
   LinalgComprehensiveModuleBufferize(
       const LinalgComprehensiveModuleBufferize &p) = default;
 
+  LinalgComprehensiveModuleBufferize(bool linalgCopy) {
+    this->useLinalgCopy = linalgCopy;
+  }
+
   void runOnOperation() override;
 
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -74,6 +78,32 @@
   return allocated;
 }
 
+/// Create a linalg::GenericOp version of an n-D copy that can further tile,
+/// lower to loops or vectorize, unlike the current implementation of
+/// memref::CopyOp.
+/// Do not depend on linalg::CopyOp that is getting deprecated.
+static LogicalResult createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
+                                        Value to) {
+  auto memrefTypeFrom = from.getType().cast<MemRefType>();
+  auto memrefTypeTo = to.getType().cast<MemRefType>();
+  if (!memrefTypeFrom || !memrefTypeTo ||
+      memrefTypeFrom.getRank() != memrefTypeTo.getRank())
+    return failure();
+  AffineMap id =
+      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
+  SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
+                                       getParallelIteratorTypeName());
+  b.create<linalg::GenericOp>(loc,
+                              /*inputs=*/from,
+                              /*outputs=*/to,
+                              /*indexingMaps=*/llvm::makeArrayRef({id, id}),
+                              /*iteratorTypes=*/iteratorTypes,
+                              [](OpBuilder &b, Location loc, ValueRange args) {
+                                b.create<linalg::YieldOp>(loc, args.front());
+                              });
+  return success();
+}
+
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
   auto options = std::make_unique<AnalysisBufferizationOptions>();
   if (useAlloca) {
@@ -82,13 +112,17 @@
       return success();
     };
   }
+  // TODO: atm memref::CopyOp can be 200x slower than linalg::GenericOp.
+  // Once this perf bug is fixed more systematically, we can revisit.
+  if (useLinalgCopy)
+    options->memCpyFn = createLinalgCopyOp;
 
   options->allowReturnMemref = allowReturnMemref;
   options->allowUnknownOps = allowUnknownOps;
   options->analysisFuzzerSeed = analysisFuzzerSeed;
-  options->testAnalysisOnly = testAnalysisOnly;
-  options->printConflicts = printConflicts;
   options->createDeallocs = createDeallocs;
+  options->printConflicts = printConflicts;
+  options->testAnalysisOnly = testAnalysisOnly;
 
   // Enable InitTensorOp elimination.
   if (initTensorElimination) {
@@ -120,3 +154,8 @@
 std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
   return std::make_unique<LinalgComprehensiveModuleBufferize>();
 }
+
+std::unique_ptr<Pass>
+mlir::createLinalgComprehensiveModuleBufferizePass(bool useLinalgCopy) {
+  return std::make_unique<LinalgComprehensiveModuleBufferize>(useLinalgCopy);
+}