diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -159,6 +159,10 @@ Value getSource() { return input();} Value getTarget() { return output(); } + + static std::function<void(Block &)> getRegionBuilder() { + return nullptr; + } }]; let verifier = [{ return ::verify(*this); }]; @@ -188,6 +192,10 @@ return Builder(getContext()).getAffineMapArrayAttr({ extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}); } + + static std::function<void(Block &)> getRegionBuilder() { + return nullptr; + } }]; let verifier = [{ return ::verify(*this); }]; @@ -261,6 +269,10 @@ if (!padding().hasValue()) return 0; return padding().getValue().getValue<int64_t>({i, 1}); } + + static std::function<void(Block &)> getRegionBuilder() { + return nullptr; + } }]; } @@ -516,6 +528,10 @@ return ss.hasValue() ? llvm::Optional<unsigned>(ss.getValue()) : llvm::None; } + + static std::function<void(Block &)> getRegionBuilder() { + return nullptr; + } }]; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseGenericOp(parser, result); }]; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -803,6 +803,17 @@ &res->getRegion(ridx), map); return res; }] + >, + StaticInterfaceMethod< + /*desc=*/[{ + Returns the region builder for constructing the body for linalg.generic. + Returns a null function if this named op does not define a region + builder.
+ }], + /*retTy=*/"std::function<void(Block &)>", + /*methodName=*/"getRegionBuilder", + (ins), + [{ return ConcreteOp::getRegionBuilder(); }] > ]; diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -55,6 +55,10 @@ void populateElementwiseToLinalgConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx); +/// Create a pass to convert named Linalg operations to Linalg generic +/// operations. +std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass(); + /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its /// producer (consumer) generic operation by expanding the dimensionality of the /// loop in the generic op. diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -112,4 +112,10 @@ let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"]; } +def LinalgGeneralization : FunctionPass<"linalg-generalize-named-ops"> { + let summary = "Convert named ops into generic ops"; + let constructor = "mlir::createLinalgGeneralizationPass()"; + let dependentDialects = ["linalg::LinalgDialect"]; +} + #endif // MLIR_DIALECT_LINALG_PASSES diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -624,6 +624,20 @@ LinalgLoweringType loweringType; }; +/// Linalg generalization patterns. + +/// Populates `patterns` with patterns to convert spec-generated named ops to
+/// linalg.generic ops. +void populateLinalgNamedOpsGeneralizationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + LinalgMarker marker = LinalgMarker()); + +/// Populates `patterns` with patterns to convert linalg.conv ops to +/// linalg.generic ops. +void populateLinalgConvGeneralizationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + LinalgMarker marker = LinalgMarker()); + //===----------------------------------------------------------------------===// // Op-specific patterns. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -5,6 +5,7 @@ ElementwiseToLinalg.cpp Fusion.cpp FusionOnTensors.cpp + Generalization.cpp Hoisting.cpp Interchange.cpp Loops.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -0,0 +1,181 @@ +//===- Generalization.cpp - linalg named ops to generic ops --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Linalg generalization pass. It converts named +// Linalg ops to linalg.generic ops.
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "linalg-generalization" + +using namespace mlir; + +// Creates a linalg.generic op from the given `namedOp`. Returns a null op if +// the given `namedOp` does not have a region builder. +static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp, + OpBuilder &builder) { + auto regionBuilder = namedOp.getRegionBuilder(); + if (!regionBuilder) { + LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n"); + return nullptr; + } + + SmallVector<AffineMap, 4> indexingMaps = namedOp.getIndexingMaps(); + auto iterators = llvm::to_vector<4>( + namedOp.iterator_types().getAsValueRange<StringAttr>()); + auto resultTypes = namedOp.getOutputTensorTypes(); + SmallVector<Type, 4> types(resultTypes.begin(), resultTypes.end()); + + return builder.create<linalg::GenericOp>( + namedOp.getLoc(), types, namedOp.getInputBuffers(), + namedOp.getOutputBuffers(), namedOp.getInitTensors(), indexingMaps, + iterators, + [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { + edsc::ScopedContext scope(bodyBuilder, loc); + regionBuilder(*bodyBuilder.getBlock()); + }); +} + +namespace { + +/// Base class for all linalg generalization patterns. A subclass must provide +/// the following method: +/// linalg::GenericOp createGenericOp(RootOp, OpBuilder &) const +/// for creating the generic op; a null result makes the pattern fail. +// TODO: remove this pattern after migrating all manually-written named ops +// into auto-generated ones.
+template +struct LinalgGeneralizationPattern : OpRewritePattern { + LinalgGeneralizationPattern(MLIRContext *context, linalg::LinalgMarker marker, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), marker(std::move(marker)) {} + + LogicalResult matchAndRewrite(RootOp rootOp, + PatternRewriter &rewriter) const override { + auto linalgOp = dyn_cast(rootOp.getOperation()); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + + auto *pattern = static_cast(this); + linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter); + if (!genericOp) + return failure(); + + rewriter.replaceOp(rootOp, genericOp.getResults()); + marker.replaceLinalgMarker(rewriter, genericOp.getOperation()); + return success(); + } + +private: + linalg::LinalgMarker marker; +}; + +struct GeneralizeConvOp + : public LinalgGeneralizationPattern { + using LinalgGeneralizationPattern::LinalgGeneralizationPattern; + + linalg::GenericOp createGenericOp(linalg::ConvOp, OpBuilder &rewriter) const; +}; + +/// Catch-all pattern for converting all named ops with a region builder into +/// linalg.generic. +struct LinalgNamedOpGeneralizationPattern : RewritePattern { + LinalgNamedOpGeneralizationPattern(MLIRContext *context, + linalg::LinalgMarker marker, + PatternBenefit benefit = 1) + : RewritePattern(benefit, MatchAnyOpTypeTag()), + marker(std::move(marker)) {} + + LogicalResult matchAndRewrite(Operation *rootOp, + PatternRewriter &rewriter) const override { + auto linalgOp = dyn_cast(rootOp); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + + // No nothing to do for linalg.generic and linalg.indexed_generic. 
+ if (isa(rootOp)) + return failure(); + + linalg::GenericOp genericOp = + createGenericOpFromNamedOp(linalgOp, rewriter); + if (!genericOp) + return failure(); + + rewriter.replaceOp(rootOp, genericOp.getResults()); + marker.replaceLinalgMarker(rewriter, genericOp.getOperation()); + return success(); + } + +private: + linalg::LinalgMarker marker; +}; + +struct LinalgGeneralizationPass + : public LinalgGeneralizationBase { + void runOnFunction() override; +}; + +} // namespace + +void LinalgGeneralizationPass::runOnFunction() { + FuncOp func = getFunction(); + OwningRewritePatternList patterns; + linalg::populateLinalgConvGeneralizationPatterns(&getContext(), patterns); + linalg::populateLinalgNamedOpsGeneralizationPatterns(&getContext(), patterns); + applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns)); +} + +linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp, + OpBuilder &builder) const { + SmallVector indexingMaps = convOp.getIndexingMaps(); + auto iterators = + llvm::to_vector<4>(convOp.iterator_types().getAsValueRange()); + return builder.create( + convOp.getLoc(), /*resultTensorTypes=*/ArrayRef(), + convOp.getInputBuffers(), convOp.getOutputBuffers(), + /*initTensors=*/ValueRange(), indexingMaps, iterators, + [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) { + Value mul = + bodyBuilder.create(bodyLoc, bodyArgs[0], bodyArgs[1]); + Value add = bodyBuilder.create(bodyLoc, mul, bodyArgs[2]); + bodyBuilder.create(bodyLoc, add); + }); +} + +void mlir::linalg::populateLinalgConvGeneralizationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + linalg::LinalgMarker marker) { + patterns.insert(context, marker); +} + +void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + linalg::LinalgMarker marker) { + patterns.insert(context, marker); +} + +std::unique_ptr> mlir::createLinalgGeneralizationPass() { + return 
std::make_unique(); +} diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -0,0 +1,54 @@ +// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s + +func @generalize_conv(%input : memref<1x225x225x3xf32>, %filter: memref<3x3x3x32xf32>, %output: memref<1x112x112x32xf32>) { + linalg.conv(%filter, %input, %output) {dilations = [2, 3], strides = [4, 5]} : memref<3x3x3x32xf32>, memref<1x225x225x3xf32>, memref<1x112x112x32xf32> + return +} + +// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> +// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 4 + d3 * 2, d2 * 5 + d4 * 3, d5)> +// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)> + +// CHECK: func @generalize_conv +// CHECK-SAME: %[[INPUT:.+]]: memref<1x225x225x3xf32> +// CHECK-SAME: %[[FILTER:.+]]: memref<3x3x3x32xf32> +// CHECK-SAME: %[[OUTPUT:.+]]: memref<1x112x112x32xf32> + +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[FILTER_MAP]], #[[INPUT_MAP]], #[[OUTPUT_MAP]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "window", "window", "reduction", "parallel"] +// CHECK-SAME: ins(%[[FILTER]], %[[INPUT]] +// CHECK-SAME: outs(%[[OUTPUT]] + +// CHECK: ^{{.*}}(%[[FILTER_ARG:.+]]: f32, %[[INPUT_ARG:.+]]: f32, %[[OUTPUT_ARG:.+]]: f32) +// CHECK: %[[MUL:.+]] = mulf %[[FILTER_ARG]], %[[INPUT_ARG]] +// CHECK: %[[ADD:.+]] = addf %[[MUL]], %[[OUTPUT_ARG]] +// CHECK: linalg.yield %[[ADD]] + +// ----- + +func @generalize_matmul(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C: memref<16x32xf32>) { + linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>) outs(%C: memref<16x32xf32>) + return +} + + +// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK: #[[B_MAP:.+]] = 
affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK: func @generalize_matmul +// CHECK-SAME: %[[A:.+]]: memref<16x8xf32> +// CHECK-SAME: %[[B:.+]]: memref<8x32xf32> +// CHECK-SAME: %[[C:.+]]: memref<16x32xf32> + +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK-SAME: ins(%[[A]], %[[B]] +// CHECK-SAME: outs(%[[C]] + +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32) +// CHECK: %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32 +// CHECK: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32 +// CHECK: linalg.yield %[[ADD]] : f32 diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1522,6 +1522,7 @@ ArrayAttr iterator_types(); ArrayAttr indexing_maps(); static void regionBuilder(Block &block); + static std::function getRegionBuilder() {{ return regionBuilder; } // Generic methods. static unsigned getNumRegionArgs() {{ return {4}; }