diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -48,6 +48,10 @@ /// buffers instead. std::unique_ptr> createLinalgBufferizePass(); +/// Create a pass to convert named Linalg operations to Linalg generic +/// operations. +std::unique_ptr> createLinalgGeneralizationPass(); + /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its /// producer (consumer) generic operation by expanding the dimensionality of the /// loop in the generic op. diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -101,4 +101,10 @@ let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"]; } +def LinalgGeneralization : FunctionPass<"linalg-generalize-named-ops"> { + let summary = "Convert named ops into generic ops"; + let constructor = "mlir::createLinalgGeneralizationPass()"; + let dependentDialects = ["linalg::LinalgDialect"]; +} + #endif // MLIR_DIALECT_LINALG_PASSES diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -625,6 +625,14 @@ LinalgLoweringType loweringType; }; +/// Linalg generalization patterns + +/// Patterns to convert linalg.conv ops to linalg.generic ops. `marker` filters +/// which ops are rewritten and is updated on the resulting generic ops. +void populateLinalgConvGeneralizationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + LinalgMarker marker = LinalgMarker()); + //===----------------------------------------------------------------------===// // Op-specific patterns.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ DropUnitDims.cpp Fusion.cpp FusionOnTensors.cpp + Generalization.cpp Hoisting.cpp Interchange.cpp Loops.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -0,0 +1,119 @@ +//===- Generalization.cpp - linalg named ops to generic ops --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Linalg generalization pass. It converts named +// Linalg ops to linalg.generic ops. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "linalg-generalization" + +using namespace mlir; + +namespace { + +/// Base class for all linalg generalization patterns. A subclass must provide +/// the following method: +/// linalg::GenericOp createGenericOp(RootOp, PatternRewriter &) +/// for creating the generic op. Only ops accepted by the pattern's marker are +/// rewritten, and a null return makes the match fail. On success, the root op +/// is replaced and `replaceLinalgMarker` is applied to the new generic op.
+template struct LinalgGeneralizationPattern : OpRewritePattern { + LinalgGeneralizationPattern(MLIRContext *context, linalg::LinalgMarker marker, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), marker(std::move(marker)) {} + + LogicalResult matchAndRewrite(RootOp rootOp, + PatternRewriter &rewriter) const override { + auto linalgOp = dyn_cast(rootOp.getOperation()); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + + auto *pattern = static_cast(this); + linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter); + if (!genericOp) + return failure(); + + rewriter.replaceOp(rootOp, genericOp.getResults()); + marker.replaceLinalgMarker(rewriter, genericOp.getOperation()); + return success(); + } + +private: + linalg::LinalgMarker marker; +}; + +/// Rewrites linalg.conv into a linalg.generic op that reuses the conv's +/// indexing maps and iterator types; the generic body accumulates +/// `filter * input` into the output element. +/// NOTE(review): the body is built from float multiply/add ops (mulf/addf); +/// confirm non-float element types cannot reach this pattern. +struct GeneralizeConvOp + : public LinalgGeneralizationPattern { + using LinalgGeneralizationPattern::LinalgGeneralizationPattern; + + linalg::GenericOp createGenericOp(linalg::ConvOp, + PatternRewriter &rewriter) const; +}; + +/// Function pass that greedily applies the linalg generalization patterns +/// (currently only the linalg.conv one) to the function body. +struct LinalgGeneralizationPass + : public LinalgGeneralizationBase { + void runOnFunction() override; +}; + +} // namespace + +linalg::GenericOp +GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp, + PatternRewriter &rewriter) const { + SmallVector indexingMaps = convOp.getIndexingMaps(); + auto iterators = + llvm::to_vector<4>(convOp.iterator_types().getAsValueRange()); + return rewriter.create( + convOp.getLoc(), /*resultTensorTypes=*/ArrayRef(), + convOp.getInputBuffers(), convOp.getOutputBuffers(), + /*initTensors=*/ValueRange(), indexingMaps, iterators, + [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) { + Value mul = + bodyBuilder.create(bodyLoc, bodyArgs[0], bodyArgs[1]); + Value add = bodyBuilder.create(bodyLoc, mul, bodyArgs[2]); + bodyBuilder.create(bodyLoc, add); + }); +} + +void LinalgGeneralizationPass::runOnFunction() { + FuncOp func = getFunction();
+ OwningRewritePatternList patterns; + linalg::populateLinalgConvGeneralizationPatterns(&getContext(), patterns); + applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns)); +} + +void mlir::linalg::populateLinalgConvGeneralizationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + linalg::LinalgMarker marker) { + patterns.insert(context, marker); +} + +std::unique_ptr> mlir::createLinalgGeneralizationPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-opt %s -linalg-generalize-named-ops -split-input-file | FileCheck %s + +// Output spatial sizes are (225 - dilation * 2 - 1) / stride + 1 (rounded +// down): 56 for stride 4 / dilation 2 and 44 for stride 5 / dilation 3. +func @generalize_conv(%input : memref<1x225x225x3xf32>, %filter: memref<3x3x3x32xf32>, %output: memref<1x56x44x32xf32>) { + linalg.conv(%filter, %input, %output) {dilations = [2, 3], strides = [4, 5]} : memref<3x3x3x32xf32>, memref<1x225x225x3xf32>, memref<1x56x44x32xf32> + return +} + +// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> +// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 4 + d3 * 2, d2 * 5 + d4 * 3, d5)> +// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)> + +// CHECK: func @generalize_conv +// CHECK-SAME: %[[INPUT:.+]]: memref<1x225x225x3xf32> +// CHECK-SAME: %[[FILTER:.+]]: memref<3x3x3x32xf32> +// CHECK-SAME: %[[OUTPUT:.+]]: memref<1x56x44x32xf32> + +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[FILTER_MAP]], #[[INPUT_MAP]], #[[OUTPUT_MAP]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "window", "window", "reduction", "parallel"] +// CHECK-SAME: ins(%[[FILTER]], %[[INPUT]] +// CHECK-SAME: outs(%[[OUTPUT]] + +// CHECK: ^{{.*}}(%[[FILTER_ARG:.+]]: f32, %[[INPUT_ARG:.+]]: f32, %[[OUTPUT_ARG:.+]]: f32) +// CHECK: 
%[[MUL:.+]] = mulf %[[FILTER_ARG]], %[[INPUT_ARG]] +// CHECK: %[[ADD:.+]] = addf %[[MUL]], %[[OUTPUT_ARG]] +// CHECK: linalg.yield %[[ADD]]