diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -108,13 +108,6 @@ const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); -/// Create a LinalgStrategyInterchangePass. -std::unique_ptr> -createLinalgStrategyInterchangePass( - ArrayRef iteratorInterchange = {}, - const linalg::LinalgTransformationFilter &filter = - linalg::LinalgTransformationFilter()); - /// Create a LinalgStrategyPeelPass. std::unique_ptr> createLinalgStrategyPeelPass( StringRef opName = "", diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -235,17 +235,6 @@ ]; } -def LinalgStrategyInterchangePass - : Pass<"linalg-strategy-interchange-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply pattern-based iterator interchange."; - let constructor = "mlir::createLinalgStrategyInterchangePass()"; - let dependentDialects = ["linalg::LinalgDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - ]; -} - def LinalgStrategyPeelPass : Pass<"linalg-strategy-peel-pass", "func::FuncOp"> { let summary = "Configurable pass to apply pattern-based linalg peeling."; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -96,23 +96,6 @@ std::string opName; }; -/// Represent one application of createLinalgStrategyInterchangePass. -struct Interchange : public Transformation { - explicit Interchange(ArrayRef iteratorInterchange, - LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)), - iteratorInterchange(iteratorInterchange.begin(), - iteratorInterchange.end()) {} - - void addToPassPipeline(OpPassManager &pm, - LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyInterchangePass(iteratorInterchange, m)); - } - -private: - SmallVector iteratorInterchange; -}; - /// Represent one application of createLinalgStrategyDecomposePass. struct Decompose : public Transformation { explicit Decompose(LinalgTransformationFilter::FilterFunction f = nullptr) @@ -249,21 +232,7 @@ generalizeIf(bool b, StringRef opName, LinalgTransformationFilter::FilterFunction f = nullptr) { return b ? generalize(opName, std::move(f)) : *this; - } - /// Append a pattern to interchange iterators. - CodegenStrategy & - interchange(ArrayRef iteratorInterchange, - const LinalgTransformationFilter::FilterFunction &f = nullptr) { - transformationSequence.emplace_back( - std::make_unique(iteratorInterchange, f)); - return *this; - } - /// Conditionally append a pattern to interchange iterators. - CodegenStrategy & - interchangeIf(bool b, ArrayRef iteratorInterchange, - LinalgTransformationFilter::FilterFunction f = nullptr) { - return b ? interchange(iteratorInterchange, std::move(f)) : *this; - } + } /// Append patterns to decompose convolutions. CodegenStrategy & decompose(const LinalgTransformationFilter::FilterFunction &f = nullptr) { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -828,38 +828,6 @@ LinalgTilingAndFusionOptions options; }; -/// -/// Linalg generic interchange pattern. -/// -/// Apply the `interchange` transformation on a RewriterBase. -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `interchange` for more details. -struct GenericOpInterchangePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - /// GenericOp-specific constructor with an optional `filter`. - GenericOpInterchangePattern( - MLIRContext *context, ArrayRef interchangeVector, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - FailureOr - returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(GenericOp op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); - } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; - /// The interchange vector to reorder the iterators and indexing_maps dims. - SmallVector interchangeVector; -}; - /// /// Linalg generalization pattern. /// diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -199,37 +199,6 @@ LinalgTransformationFilter filter; }; -/// Configurable pass to apply pattern-based linalg generalization. -struct LinalgStrategyInterchangePass - : public LinalgStrategyInterchangePassBase { - - LinalgStrategyInterchangePass() = default; - - LinalgStrategyInterchangePass(ArrayRef iteratorInterchange, - LinalgTransformationFilter filter) - : iteratorInterchange(iteratorInterchange.begin(), - iteratorInterchange.end()), - filter(std::move(filter)) {} - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - - SmallVector interchangeVector(iteratorInterchange.begin(), - iteratorInterchange.end()); - RewritePatternSet interchangePattern(funcOp.getContext()); - interchangePattern.add( - funcOp.getContext(), interchangeVector, filter); - if (failed(applyPatternsAndFoldGreedily(funcOp, - std::move(interchangePattern)))) - signalPassFailure(); - } - - SmallVector iteratorInterchange; - LinalgTransformationFilter filter; -}; - /// Configurable pass to apply pattern-based linalg peeling. struct LinalgStrategyPeelPass : public LinalgStrategyPeelPassBase { @@ -491,15 +460,6 @@ return std::make_unique(filter); } -/// Create a LinalgStrategyInterchangePass. -std::unique_ptr> -mlir::createLinalgStrategyInterchangePass( - ArrayRef iteratorInterchange, - const LinalgTransformationFilter &filter) { - return std::make_unique(iteratorInterchange, - filter); -} - /// Create a LinalgStrategyPeelPass. std::unique_ptr> mlir::createLinalgStrategyPeelPass(StringRef opName, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -537,29 +537,6 @@ return tileLoopNest; } -/// Linalg generic interchange pattern. -mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern( - MLIRContext *context, ArrayRef interchangeVector, - LinalgTransformationFilter f, PatternBenefit benefit) - : OpRewritePattern(context, benefit), filter(std::move(f)), - interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} - -FailureOr -mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite( - GenericOp genericOp, PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, genericOp))) - return failure(); - - FailureOr transformedOp = - interchangeGenericOp(rewriter, genericOp, interchangeVector); - if (failed(transformedOp)) - return failure(); - - // New filter if specified. - filter.replaceLinalgTransformationFilter(rewriter, genericOp); - return transformedOp; -} - /// Linalg generalization pattern. mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern( MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit) diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns -split-input-file -test-transform-dialect-interpreter | FileCheck %s // CHECK-DAG: #[[$STRIDED_1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // Map corresponding to a 2D memory access where the stride along the last dim is known to be 1. @@ -114,6 +114,14 @@ } return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + transform.structured.interchange %0 { iterator_interchange = [1, 2, 0]} + } +} // CHECK-LABEL: func @permute_generic // CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]], diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -230,15 +230,7 @@ .addOpFilter()); patterns.add(ctx); - //===--------------------------------------------------------------------===// - // Linalg generic interchange pattern. - //===--------------------------------------------------------------------===// - patterns.add( - ctx, - /*interchangeVector=*/ArrayRef{1, 2, 0}, - LinalgTransformationFilter(ArrayRef{}, - StringAttr::get(ctx, "PERMUTED"))); - + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); // Drop the marker.