diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -95,6 +95,12 @@ linalg::LinalgTransformationFilter filter = linalg::LinalgTransformationFilter()); +/// Create a LinalgStrategyGeneralizePass. +std::unique_ptr<OperationPass<FuncOp>> +createLinalgStrategyGeneralizePass(StringRef opName = "", + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter()); + /// Create a LinalgStrategyVectorizePass. std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyVectorizePass(StringRef opName = "", diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -255,6 +255,19 @@ ]; } +def LinalgStrategyGeneralizePass + : FunctionPass<"linalg-strategy-generalize-pass"> { + let summary = "Configurable pass to apply pattern-based generalization."; + let constructor = "mlir::createLinalgStrategyGeneralizePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + "Which linalg op within the func is the anchor to latch on.">, + ]; +} + def LinalgStrategyVectorizePass : FunctionPass<"linalg-strategy-vectorize-pass"> { let summary = "Configurable pass to apply pattern-based linalg vectorization."; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -62,6 +62,21 @@ linalg::LinalgPromotionOptions options; }; +/// Represent one application of createLinalgStrategyGeneralizePass.
+struct Generalize : public Transformation { + explicit Generalize(StringRef name, + LinalgTransformationFilter::FilterFunction f = nullptr) + : Transformation(f), opName(name) {} + + void addToPassPipeline(OpPassManager &pm, + LinalgTransformationFilter m) const override { + pm.addPass(createLinalgStrategyGeneralizePass(opName, m)); + } + +private: + std::string opName; +}; + /// Represent one application of createLinalgStrategyVectorizePass. struct Vectorize : public Transformation { explicit Vectorize(linalg::LinalgVectorizationOptions options, @@ -117,6 +132,21 @@ return b ? promote(opName, options, f) : *this; return *this; } + /// Append a pattern to generalize named operations. + CodegenStrategy & + generalize(StringRef opName, + LinalgTransformationFilter::FilterFunction f = nullptr) { + transformationSequence.emplace_back( + std::make_unique<Generalize>(opName, f)); + return *this; + } + /// Conditionally append a pattern to generalize named operations. + CodegenStrategy & + generalizeIf(bool b, StringRef opName, + LinalgTransformationFilter::FilterFunction f = nullptr) { + return b ? generalize(opName, f) : *this; + return *this; + } /// Append a pattern to rewrite `LinalgOpType` as a vector operation. CodegenStrategy & vectorize(StringRef opName, diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -68,6 +68,39 @@ LinalgTransformationFilter filter; }; +/// Configurable pass to apply pattern-based linalg generalization.
+struct LinalgStrategyGeneralizePass + : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> { + + LinalgStrategyGeneralizePass() = default; + + LinalgStrategyGeneralizePass(StringRef opName, + LinalgTransformationFilter filter) + : filter(filter) { + this->anchorOpName.setValue(opName.str()); + } + + void runOnFunction() override { + auto funcOp = getFunction(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + + RewritePatternSet generalizationPattern(funcOp.getContext()); + if (!anchorOpName.empty()) { + generalizationPattern.add<LinalgGeneralizationPattern>( + anchorOpName, funcOp.getContext(), filter); + } else { + generalizationPattern.add<LinalgGeneralizationPattern>( + funcOp.getContext(), filter); + } + if (failed(applyPatternsAndFoldGreedily(funcOp, + std::move(generalizationPattern)))) + signalPassFailure(); + } + + LinalgTransformationFilter filter; +}; + /// Configurable pass to apply pattern-based linalg promotion. struct LinalgStrategyPromotePass : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> { @@ -233,6 +266,13 @@ return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter); } +/// Create a LinalgStrategyGeneralizePass. +std::unique_ptr<OperationPass<FuncOp>> +mlir::createLinalgStrategyGeneralizePass(StringRef opName, + LinalgTransformationFilter filter) { + return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); +} + /// Create a LinalgStrategyVectorizePass.
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(StringRef opName, diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir --- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir +++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir @@ -4,9 +4,12 @@ // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER +// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize" | FileCheck %s --check-prefix=GENER + // CHECK-LABEL: func @matmul( // OUTER-LABEL: func @matmul( +// GENER-LABEL: func @matmul( func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) { linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>) @@ -17,6 +20,7 @@ // CHECK-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32> // OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32> + // GENER: linalg.generic return } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp +++ 
b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp @@ -86,6 +86,9 @@ *this, "register-promote-full-tile-pad", llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."), llvm::cl::init(false)}; + Option<bool> generalize{*this, "generalize", + llvm::cl::desc("Generalize named operations."), + llvm::cl::init(false)}; Option<bool> vectorize{ *this, "vectorize", llvm::cl::desc("Rewrite the linalg op as a vector operation."), @@ -133,6 +136,7 @@ vector::VectorTransferSplit vectorTransferSplit) { assert(!anchorOpName.empty()); CodegenStrategy strategy; + StringRef genericOpName = GenericOp::getOperationName(); strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions) .promoteIf(promote, anchorOpName, LinalgPromotionOptions() @@ -143,7 +147,8 @@ LinalgPromotionOptions() .setAlignment(16) .setUseFullTileBuffersByDefault(registerPromoteFullTile)) - .vectorizeIf(vectorize, anchorOpName) + .generalizeIf(generalize, anchorOpName) + .vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName) .setEnableVectorTransferPartialRewrite(true) .setEnableVectorContractLowering(true) .setEnableVectorToSCFConversion(true)