diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -125,6 +125,17 @@ SmallVector<int64_t> iteratorInterchange; }; +/// Represent one application of createLinalgStrategyDecomposePass. +struct Decompose : public Transformation { + explicit Decompose(LinalgTransformationFilter::FilterFunction f = nullptr) + : Transformation(f) {} + + void addToPassPipeline(OpPassManager &pm, + LinalgTransformationFilter m) const override { + pm.addPass(createLinalgStrategyDecomposePass(m)); + } +}; + /// Represent one application of createLinalgStrategyVectorizePass. struct Vectorize : public Transformation { explicit Vectorize(linalg::LinalgVectorizationOptions options, @@ -263,6 +274,18 @@ return b ? interchange(iteratorInterchange, f) : *this; return *this; } + /// Append patterns to decompose convolutions. + CodegenStrategy & + decompose(LinalgTransformationFilter::FilterFunction f = nullptr) { + transformationSequence.emplace_back(std::make_unique<Decompose>(f)); + return *this; + } + /// Conditionally append patterns to decompose convolutions. + CodegenStrategy & + decomposeIf(bool b, LinalgTransformationFilter::FilterFunction f = nullptr) { + return b ? decompose(f) : *this; + return *this; + } /// Append a pattern to rewrite `LinalgOpType` as a vector operation. 
CodegenStrategy & vectorize(StringRef opName, diff --git a/mlir/test/Dialect/Linalg/decompose-convolution.mlir b/mlir/test/Dialect/Linalg/decompose-convolution.mlir --- a/mlir/test/Dialect/Linalg/decompose-convolution.mlir +++ b/mlir/test/Dialect/Linalg/decompose-convolution.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-decompose-convolution-patterns %s | FileCheck %s +// RUN: mlir-opt -test-linalg-codegen-strategy="decompose" -split-input-file %s | FileCheck %s // CHECK-LABEL: func @conv2d_nhwc_4x1x2x8_tensor // CHECK-SAME: (%[[INPUT:.+]]: tensor<4x1x6x3xf32>, %[[FILTER:.+]]: tensor<1x2x3x8xf32>, %[[INIT:.+]]: tensor<4x1x2x8xf32>) diff --git a/mlir/test/Dialect/Linalg/interchange.mlir b/mlir/test/Dialect/Linalg/interchange.mlir --- a/mlir/test/Dialect/Linalg/interchange.mlir +++ b/mlir/test/Dialect/Linalg/interchange.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.generic iterator-interchange=4,0,3,1,2" | FileCheck %s -// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.generic iterator-interchange=4,0,3,1,2" -test-linalg-codegen-strategy="anchor-op=linalg.generic iterator-interchange=1,3,4,2,0" | FileCheck --check-prefix=CANCEL-OUT %s +// RUN: mlir-opt %s -test-linalg-codegen-strategy="iterator-interchange=4,0,3,1,2" | FileCheck %s +// RUN: mlir-opt %s -test-linalg-codegen-strategy="iterator-interchange=4,0,3,1,2" -test-linalg-codegen-strategy="iterator-interchange=1,3,4,2,0" | FileCheck --check-prefix=CANCEL-OUT %s #map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp @@ -112,6 +112,10 @@ ListOption<int64_t> iteratorInterchange{ *this, 
"iterator-interchange", llvm::cl::MiscFlags::CommaSeparated, llvm::cl::desc("Specifies the iterator interchange.")}; + Option<bool> decompose{ + *this, "decompose", + llvm::cl::desc("Decompose convolutions to lower dimensional ones."), + llvm::cl::init(false)}; Option<bool> vectorize{ *this, "vectorize", llvm::cl::desc("Rewrite the linalg op as a vector operation."), @@ -163,7 +167,6 @@ LinalgPaddingOptions paddingOptions, vector::VectorContractLowering vectorContractLowering, vector::VectorTransferSplit vectorTransferSplit) { - assert(!anchorOpName.empty()); CodegenStrategy strategy; strategy .tileAndFuseIf(fuse && !tileSizes.empty(), anchorOpName, @@ -180,6 +183,7 @@ .setAlignment(16) .setUseFullTileBuffersByDefault(registerPromoteFullTile)) .padIf(pad, "", paddingOptions) + .decomposeIf(decompose) .generalizeIf(generalize, "") .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange) .vectorizeIf(vectorize, "") diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -128,11 +128,6 @@ llvm::cl::desc("Specify the type of loops to generate: for, parallel or " "tiled_loop"), llvm::cl::init("for")}; - Option<bool> testDecomposeConvolutionPattern{ - *this, "test-decompose-convolution-patterns", - llvm::cl::desc("Test a set of patterns to rewrite high-D convolution ops " - "into low-D ones"), - llvm::cl::init(false)}; }; } // end anonymous namespace @@ -721,13 +716,6 @@ if (testTileScalarizeDynamicDims) return applyTilePattern(getFunction(), loopType, tileSizes, /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); - if (testDecomposeConvolutionPattern) { - // TODO: thread all tests through LinalgStrategy passes. 
- OpPassManager dynamicPM("builtin.func"); - dynamicPM.addPass(createLinalgStrategyDecomposePass()); - if (failed(runPipeline(dynamicPM, getFunction()))) - return signalPassFailure(); - } } namespace mlir {