diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -367,6 +367,23 @@ LinalgLoweringType loweringType; }; +//===----------------------------------------------------------------------===// +// Support for staged pattern application. +//===----------------------------------------------------------------------===// +/// Helper function to allow applying rewrite patterns, interleaved with more +/// global transformations, in a staged fashion. For each element of +/// `stage1Patterns`, in order (nothing is applied if it is empty): +/// 1. the first stage applies that OwningRewritePatternList greedily; +/// 2. the second stage applies `stage2Patterns` greedily until convergence; +/// 3. the third stage invokes `stage3Lambda` on `op`, if non-null; it is +/// generally used for non-local transformation effects. +/// Returns failure if a greedy rewrite does not converge or the lambda fails. +/// This allows creating custom fused transformations where patterns can be +/// ordered and applied at a finer granularity than a sequence of passes. +LogicalResult applyStagedPatterns( + Operation *op, ArrayRef stage1Patterns, + const OwningRewritePatternList &stage2Patterns, + llvm::function_ref stage3Lambda = nullptr); } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -388,6 +388,15 @@ using PatternListT = std::vector>; public: + OwningRewritePatternList() = default; + + /// Construct an OwningRewritePatternList populated with the pattern `t` of + /// type `T`.
+ template <typename T> + OwningRewritePatternList(T &&t) { + patterns.push_back(std::make_unique<std::decay_t<T>>(std::forward<T>(t))); + } + PatternListT::iterator begin() { return patterns.begin(); } PatternListT::iterator end() { return patterns.end(); } PatternListT::const_iterator begin() const { return patterns.begin(); } @@ -399,12 +408,13 @@ //===--------------------------------------------------------------------===// /// Add an instance of each of the pattern types 'Ts' to the pattern list with - /// the given arguments. + /// the given arguments. Returns `*this` to allow chaining insertions. /// Note: ConstructorArg is necessary here to separate the two variadic lists. template > - void insert(ConstructorArg &&arg, ConstructorArgs &&... args) { + OwningRewritePatternList &insert(ConstructorArg &&arg, + ConstructorArgs &&... args) { // The following expands a call to emplace_back for each of the pattern // types 'Ts'. This magic is necessary due to a limitation in the places // that a parameter pack can be expanded in c++11.
@@ -412,6 +422,7 @@ using dummy = int[]; (void)dummy{ 0, (patterns.emplace_back(std::make_unique(arg, args...)), 0)...}; + return *this; } private: diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -198,3 +198,24 @@ rewriter.eraseOp(op); return success(); } + +LogicalResult mlir::linalg::applyStagedPatterns( + Operation *op, ArrayRef stage1Patterns, + const OwningRewritePatternList &stage2Patterns, + llvm::function_ref stage3Lambda) { + for (const auto &patterns : stage1Patterns) { + if (!applyPatternsAndFoldGreedily(op, patterns)) { + llvm::dbgs() << "Underlying first stage rewrite did not converge"; + return failure(); + } + if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) { + llvm::dbgs() << "Underlying second stage rewrite did not converge"; + return failure(); + } + if (stage3Lambda) { + if (failed(stage3Lambda(op))) + return failure(); + } + } + return success(); +} diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir @@ -0,0 +1,34 @@ +// TODO: this needs a fix to land before being reactivated. 
+// RUN: ls +// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s +// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s + +func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) { + linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "START"} : + memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + memref<1584x1584xf32, offset: 0, strides: [1584, 1]> + return +} + +// CHECK-LABEL:func @matmul +// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32> +// CHECK: store {{.*}}[] : memref> +// +// CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32> +// CHECK: store {{.*}}[] : memref> +// +// CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32> +// CHECK: store {{.*}}[] : memref> +// +// CHECK: linalg.copy +// CHECK: linalg.copy +// CHECK: linalg.copy +// +// CHECK: vector.contract +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32> +// +// CHECK: linalg.copy diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -33,6 +33,18 @@ Option testPatterns{*this, "test-patterns", llvm::cl::desc("Test a mixed set of patterns"), llvm::cl::init(false)}; + Option testMatmulToVectorPatterns1dTiling{ + *this, "test-matmul-to-vector-patterns-tile-1d", + llvm::cl::desc( + "Test a fused pass that applies patterns from matmul to vectors via " + "1-d tiling"), + llvm::cl::init(false)}; + Option testMatmulToVectorPatterns2dTiling{ + *this, "test-matmul-to-vector-patterns-tile-2d", + llvm::cl::desc( + "Test a 
fused pass that applies patterns from matmul to vectors via " + "2-d tiling"), + llvm::cl::init(false)}; }; } // end anonymous namespace @@ -137,10 +149,71 @@ }); } +/// Returns the canonicalization patterns of AffineApplyOp, AffineMinOp, +/// AffineMaxOp, AllocOp, SubViewOp, ViewOp and MatmulOp. +static OwningRewritePatternList +getMatmulToVectorCanonicalizationPatterns(MLIRContext *context) { + OwningRewritePatternList patterns; + AffineApplyOp::getCanonicalizationPatterns(patterns, context); + AffineMinOp::getCanonicalizationPatterns(patterns, context); + AffineMaxOp::getCanonicalizationPatterns(patterns, context); + AllocOp::getCanonicalizationPatterns(patterns, context); + SubViewOp::getCanonicalizationPatterns(patterns, context); + ViewOp::getCanonicalizationPatterns(patterns, context); + MatmulOp::getCanonicalizationPatterns(patterns, context); + return patterns; +} + +/// Appends three stage-1 pattern lists to `patternsVector`: +/// - tiling with sizes {8, 12, 16}, LinalgMarker {startMarker} -> "L1"; +/// - promotion, LinalgMarker {"L1"} -> "VEC"; +/// - vectorization of {"VEC"}, extended via insert() with more patterns. +static void fillL1TilingAndMatmulToVectorPatterns( + MLIRContext *context, StringRef startMarker, + SmallVectorImpl &patternsVector) { + patternsVector.emplace_back(LinalgTilingPattern( + context, + LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}), + LinalgMarker({startMarker}, "L1"))); + + patternsVector.emplace_back(LinalgPromotionPattern( + context, LinalgPromotionOptions(), LinalgMarker({"L1"}, "VEC"))); + + patternsVector.emplace_back( + LinalgVectorizationPattern(context, LinalgMarker({"VEC"}))); + patternsVector.back() + .insert, + LinalgVectorizationPattern>(context); +} + /// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnFunction() { - if (testPatterns) - return applyPatterns(getFunction()); + if (testPatterns) { + applyPatterns(getFunction()); + } else { + SmallVector stage1Patterns; + if (testMatmulToVectorPatterns1dTiling) { + fillL1TilingAndMatmulToVectorPatterns(&getContext(), "START", + stage1Patterns); + } else if (testMatmulToVectorPatterns2dTiling) { + stage1Patterns.emplace_back( + LinalgTilingPattern(&getContext(), + LinalgTilingOptions() + .setTileSizes({768, 264, 768}) + .setInterchange({1, 2, 0}), + LinalgMarker({"START"}, "L2"))); + fillL1TilingAndMatmulToVectorPatterns(&getContext(), "L2", + stage1Patterns); + } + if (failed(applyStagedPatterns( + getFunction(), stage1Patterns, + getMatmulToVectorCanonicalizationPatterns(&getContext())))) + return signalPassFailure(); + } + // Drop the transformation marker from all Linalg ops. + getFunction().walk([](LinalgOp op) { + op.removeAttr(LinalgTransforms::kLinalgTransformMarker); + }); } namespace mlir {