diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -113,6 +113,10 @@ // this should be exposed with an opName + filter and a proper pattern. std::unique_ptr> createLinalgStrategyDecomposePass(); +/// Create a LinalgStrategySwapSliceOfPadPass that greedily applies patterns to +/// rewrite extract_slice(pad_tensor) into pad_tensor(extract_slice). +std::unique_ptr> createLinalgStrategySwapSliceOfPadPass(); + /// Create a LinalgStrategyInterchangePass. std::unique_ptr> createLinalgStrategyInterchangePass(ArrayRef iteratorInterchange = {}, diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -300,6 +300,19 @@ ]; } +def LinalgStrategySwapSliceOfPadPass + : FunctionPass<"linalg-strategy-swap-slice-of-pad-pass"> { + let summary = "Configurable pass to convert extract_slice(pad_tensor) into " + "pad_tensor(extract_slice)."; + let constructor = "mlir::createLinalgStrategySwapSliceOfPadPass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + ]; +} + + def LinalgStrategyInterchangePass : FunctionPass<"linalg-strategy-interchange-pass"> { let summary = "Configurable pass to apply pattern-based iterator interchange."; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -153,6 +153,23 @@ } }; +/// Configurable pass to greedily rewrite extract_slice(pad_tensor) into +/// pad_tensor(extract_slice). Skips functions other than anchor-func, if set. 
+struct LinalgStrategySwapSliceOfPadPass + : public LinalgStrategySwapSliceOfPadPassBase< + LinalgStrategySwapSliceOfPadPass> { + void runOnFunction() override { + auto funcOp = getFunction(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + MLIRContext *context = funcOp.getContext(); + RewritePatternSet patterns(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) + signalPassFailure(); + } +}; + /// Configurable pass to apply pattern-based linalg generalization. struct LinalgStrategyInterchangePass : public LinalgStrategyInterchangePassBase { @@ -416,6 +433,11 @@ return std::make_unique(); } +std::unique_ptr> +mlir::createLinalgStrategySwapSliceOfPadPass() { + return std::make_unique(); +} + /// Create a LinalgStrategyInterchangePass. std::unique_ptr> mlir::createLinalgStrategyInterchangePass(ArrayRef iteratorInterchange,