diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -367,6 +367,23 @@ LinalgLoweringType loweringType; }; +//===----------------------------------------------------------------------===// +// Support for staged pattern application. +//===----------------------------------------------------------------------===// +/// Helper function to allow applying rewrite patterns, interleaved with more +/// global transformations, in a staged fashion. For each list in +/// `stage1Patterns`, visited in order: +/// 1. the first stage applies that OwningRewritePatternList greedily. +/// 2. the second stage applies the single OwningRewritePatternList +/// `stage2Patterns` greedily. +/// 3. the third stage calls `stage3Lambda` on `op` if one was provided; it is +/// generally used for non-local transformation effects. This allows creating +/// custom fused transformations where patterns can be ordered and applied at +/// a finer granularity than a sequence of traditional compiler passes. +void applyStagedPatterns( + Operation *op, SmallVector &stage1Patterns, + OwningRewritePatternList &stage2Patterns, + std::function stage3Lambda = nullptr); } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -399,12 +399,13 @@ //===--------------------------------------------------------------------===// /// Add an instance of each of the pattern types 'Ts' to the pattern list with - /// the given arguments. + /// the given arguments. Return a reference to `this` for chaining insertions. /// Note: ConstructorArg is necessary here to separate the two variadic lists. template > - void insert(ConstructorArg &&arg, ConstructorArgs &&... 
args) { + OwningRewritePatternList &insert(ConstructorArg &&arg, + ConstructorArgs &&... args) { // The following expands a call to emplace_back for each of the pattern // types 'Ts'. This magic is necessary due to a limitation in the places // that a parameter pack can be expanded in c++11. @@ -412,6 +413,7 @@ using dummy = int[]; (void)dummy{ 0, (patterns.emplace_back(std::make_unique(arg, args...)), 0)...}; + return *this; } private: diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -198,3 +198,17 @@ rewriter.eraseOp(op); return success(); } + +void mlir::linalg::applyStagedPatterns( + Operation *op, SmallVector &stage1Patterns, + OwningRewritePatternList &stage2Patterns, + std::function stage3Lambda) { + for (auto &patterns : stage1Patterns) { + if (!applyPatternsAndFoldGreedily(op, patterns)) + llvm::dbgs() << "Underlying first stage rewrite did not converge\n"; + if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) + llvm::dbgs() << "Underlying second stage rewrite did not converge\n"; + if (stage3Lambda) + stage3Lambda(op); + } +} diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir @@ -0,0 +1,32 @@ +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s + +func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) { + linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "START"} : + memref<1584x1584xf32, 
offset: 0, strides: [1584, 1]>, + memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + memref<1584x1584xf32, offset: 0, strides: [1584, 1]> + return +} + +// CHECK-LABEL:func @matmul +// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32> +// CHECK: store {{.*}}[] : memref> +// +// CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32> +// CHECK: store {{.*}}[] : memref> +// +// CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32> +// CHECK: store {{.*}}[] : memref> +// +// CHECK: linalg.copy +// CHECK: linalg.copy +// CHECK: linalg.copy +// +// CHECK: vector.contract +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32> +// +// CHECK: linalg.copy diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -33,6 +33,18 @@ Option testPatterns{*this, "test-patterns", llvm::cl::desc("Test a mixed set of patterns"), llvm::cl::init(false)}; + Option testMatmulToVectorPatterns1dTiling{ + *this, "test-matmul-to-vector-patterns-tile-1d", + llvm::cl::desc( + "Test a fused pass that applies patterns from matmul to vectors via " + "1-d tiling"), + llvm::cl::init(false)}; + Option testMatmulToVectorPatterns2dTiling{ + *this, "test-matmul-to-vector-patterns-tile-2d", + llvm::cl::desc( + "Test a fused pass that applies patterns from matmul to vectors via " + "2-d tiling"), + llvm::cl::init(false)}; }; } // end anonymous namespace @@ -137,10 +149,71 @@ }); } +OwningRewritePatternList +getMatmulToVectorCanonicalizationPatterns(MLIRContext *context) { + OwningRewritePatternList patterns; + AffineApplyOp::getCanonicalizationPatterns(patterns, context); + AffineMinOp::getCanonicalizationPatterns(patterns, context); + AffineMaxOp::getCanonicalizationPatterns(patterns, context); + 
AllocOp::getCanonicalizationPatterns(patterns, context); + SubViewOp::getCanonicalizationPatterns(patterns, context); + ViewOp::getCanonicalizationPatterns(patterns, context); + MatmulOp::getCanonicalizationPatterns(patterns, context); + return patterns; +} + +void fillL1TilingAndMatmulToVectorPatterns( + MLIRContext *context, StringRef startMarker, + SmallVectorImpl &patternsVector) { + patternsVector.emplace_back(std::move( + OwningRewritePatternList().insert>( + context, + LinalgTilingOptions() + .setTileSizes({8, 12, 16}) + .setInterchange({1, 0, 2}), + LinalgMarker({startMarker}, "L1")))); + + patternsVector.emplace_back(std::move( + OwningRewritePatternList().insert>( + context, LinalgPromotionOptions(), LinalgMarker({"L1"}, "VEC")))); + + patternsVector.emplace_back( + std::move(OwningRewritePatternList() + .insert>( + context, LinalgMarker({"VEC"})) + .insert, + LinalgVectorizationPattern>(context))); +} + /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnFunction() { - if (testPatterns) - return applyPatterns(getFunction()); + if (testPatterns) { + applyPatterns(getFunction()); + } else { + SmallVector stage1Patterns; + if (testMatmulToVectorPatterns1dTiling) { + fillL1TilingAndMatmulToVectorPatterns(&getContext(), "START", + stage1Patterns); + } else if (testMatmulToVectorPatterns2dTiling) { + stage1Patterns.emplace_back(std::move( + OwningRewritePatternList().insert>( + &getContext(), + LinalgTilingOptions() + .setTileSizes({768, 264, 768}) + .setInterchange({1, 2, 0}), + LinalgMarker({"START"}, "L2")))); + fillL1TilingAndMatmulToVectorPatterns(&getContext(), "L2", + stage1Patterns); + } + OwningRewritePatternList stage2Patterns = + getMatmulToVectorCanonicalizationPatterns(&getContext()); + applyStagedPatterns(getFunction(), stage1Patterns, stage2Patterns); + } + + // Drop the marker from all LinalgOps, whichever option was selected above. + getFunction().walk([](LinalgOp op) { + op.removeAttr(LinalgTransforms::kLinalgTransformMarker); + }); } namespace mlir {