diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -101,6 +101,12 @@ linalg::LinalgTransformationFilter filter = linalg::LinalgTransformationFilter()); +/// Create a LinalgStrategyInterchangePass. +std::unique_ptr<OperationPass<FuncOp>> +createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange = {}, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter()); + /// Create a LinalgStrategyVectorizePass. std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyVectorizePass(StringRef opName = "", diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -268,6 +268,17 @@ ]; } +def LinalgStrategyInterchangePass + : FunctionPass<"linalg-strategy-interchange-pass"> { + let summary = "Configurable pass to apply pattern-based iterator interchange."; + let constructor = "mlir::createLinalgStrategyInterchangePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + ]; +} + def LinalgStrategyVectorizePass : FunctionPass<"linalg-strategy-vectorize-pass"> { let summary = "Configurable pass to apply pattern-based linalg vectorization."; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -77,6 +77,22 @@ std::string opName; }; +/// Represent one application of createLinalgStrategyInterchangePass.
+struct Interchange : public Transformation { + explicit Interchange(ArrayRef<int64_t> iteratorInterchange, + LinalgTransformationFilter::FilterFunction f = nullptr) + : Transformation(f), iteratorInterchange(iteratorInterchange.begin(), + iteratorInterchange.end()) {} + + void addToPassPipeline(OpPassManager &pm, + LinalgTransformationFilter m) const override { + pm.addPass(createLinalgStrategyInterchangePass(iteratorInterchange, m)); + } + +private: + SmallVector<int64_t> iteratorInterchange; +}; + /// Represent one application of createLinalgStrategyVectorizePass. struct Vectorize : public Transformation { explicit Vectorize(linalg::LinalgVectorizationOptions options, @@ -147,6 +163,20 @@ return b ? generalize(opName, f) : *this; return *this; } + /// Append a pattern to interchange iterators. + CodegenStrategy & + interchange(ArrayRef<int64_t> iteratorInterchange, + LinalgTransformationFilter::FilterFunction f = nullptr) { + transformationSequence.emplace_back( + std::make_unique<Interchange>(iteratorInterchange, f)); + return *this; + } + /// Conditionally append a pattern to interchange iterators. + CodegenStrategy & + interchangeIf(bool b, ArrayRef<int64_t> iteratorInterchange, + LinalgTransformationFilter::FilterFunction f = nullptr) { + return b ? interchange(iteratorInterchange, f) : *this; + } /// Append a pattern to rewrite `LinalgOpType` as a vector operation. CodegenStrategy & vectorize(StringRef opName, diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -101,6 +101,37 @@ LinalgTransformationFilter filter; }; +/// Configurable pass to apply pattern-based iterator interchange.
+struct LinalgStrategyInterchangePass + : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { + + LinalgStrategyInterchangePass() = default; + + LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, + LinalgTransformationFilter filter) + : iteratorInterchange(iteratorInterchange.begin(), + iteratorInterchange.end()), + filter(filter) {} + + void runOnFunction() override { + auto funcOp = getFunction(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + + SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(), + iteratorInterchange.end()); + RewritePatternSet interchangePattern(funcOp.getContext()); + interchangePattern.add<GenericOpInterchangePattern>( + funcOp.getContext(), interchangeVector, filter); + if (failed(applyPatternsAndFoldGreedily(funcOp, + std::move(interchangePattern)))) + signalPassFailure(); + } + + SmallVector<int64_t> iteratorInterchange; + LinalgTransformationFilter filter; +}; + /// Configurable pass to apply pattern-based linalg promotion. struct LinalgStrategyPromotePass : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> { @@ -273,6 +304,14 @@ return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); } +/// Create a LinalgStrategyInterchangePass. +std::unique_ptr<OperationPass<FuncOp>> +mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, + LinalgTransformationFilter filter) { + return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, + filter); +} + /// Create a LinalgStrategyVectorizePass.
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(StringRef opName, diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir --- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir +++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir @@ -4,7 +4,7 @@ // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER -// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize" | FileCheck %s --check-prefix=GENER +// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize iterator-interchange=0,2,1" | FileCheck %s --check-prefix=GENER // CHECK-LABEL: func @matmul( @@ -19,8 +19,10 @@ // CHECK-SAME: {lhs_columns = 8 : i32, lhs_rows = 2 : i32, rhs_columns = 4 : i32} // CHECK-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32> - // OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32> - // GENER: linalg.generic + // OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32> + + // GENER: linalg.generic + // GENER-SAME: iterator_types = ["parallel", "reduction", "parallel"] return } diff --git
a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp @@ -89,6 +89,9 @@ Option<bool> generalize{*this, "generalize", llvm::cl::desc("Generalize named operations."), llvm::cl::init(false)}; + ListOption<int64_t> iteratorInterchange{ + *this, "iterator-interchange", llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::desc("Specifies the iterator interchange.")}; Option<bool> vectorize{ *this, "vectorize", llvm::cl::desc("Rewrite the linalg op as a vector operation."), @@ -148,6 +151,7 @@ .setAlignment(16) .setUseFullTileBuffersByDefault(registerPromoteFullTile)) .generalizeIf(generalize, anchorOpName) + .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange) .vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName) .setEnableVectorTransferPartialRewrite(true) .setEnableVectorContractLowering(true)