Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
Show First 20 Lines • Show All 128 Lines • ▼ Show 20 Lines | void runOnFunction() override { | ||||
if (failed(applyPatternsAndFoldGreedily(funcOp, | if (failed(applyPatternsAndFoldGreedily(funcOp, | ||||
std::move(generalizationPattern)))) | std::move(generalizationPattern)))) | ||||
signalPassFailure(); | signalPassFailure(); | ||||
} | } | ||||
LinalgTransformationFilter filter; | LinalgTransformationFilter filter; | ||||
}; | }; | ||||
/// Configurable pass to apply lowering of coarser-grained named linalg ops into | |||||
/// finer-grained named versions. | |||||
struct LinalgStrategyDecomposePass | |||||
: public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> { | |||||
LinalgStrategyDecomposePass() = default; | |||||
void runOnFunction() override { | |||||
auto funcOp = getFunction(); | |||||
if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) | |||||
return; | |||||
RewritePatternSet decompositionPattern(funcOp.getContext()); | |||||
populateDecomposeConvolutionPatterns(decompositionPattern); | |||||
if (failed(applyPatternsAndFoldGreedily(funcOp, | |||||
std::move(decompositionPattern)))) | |||||
signalPassFailure(); | |||||
} | |||||
}; | |||||
/// Configurable pass to apply pattern-based linalg generalization. | /// Configurable pass to apply pattern-based linalg generalization. | ||||
struct LinalgStrategyInterchangePass | struct LinalgStrategyInterchangePass | ||||
: public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { | : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> { | ||||
LinalgStrategyInterchangePass() = default; | LinalgStrategyInterchangePass() = default; | ||||
LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, | LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, | ||||
LinalgTransformationFilter filter) | LinalgTransformationFilter filter) | ||||
▲ Show 20 Lines • Show All 239 Lines • ▼ Show 20 Lines | |||||
} | } | ||||
/// Create a LinalgStrategyGeneralizePass. | /// Create a LinalgStrategyGeneralizePass. | ||||
std::unique_ptr<OperationPass<FuncOp>> | std::unique_ptr<OperationPass<FuncOp>> | ||||
mlir::createLinalgStrategyGeneralizePass(StringRef opName, | mlir::createLinalgStrategyGeneralizePass(StringRef opName, | ||||
LinalgTransformationFilter filter) { | LinalgTransformationFilter filter) { | ||||
return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); | return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter); | ||||
} | } | ||||
/// Create a LinalgStrategyDecomposePass. | |||||
// TODO: atm this is applied to all supported ops. If/when we need finer control | |||||
// this should be exposed with an opName + filter and a proper pattern. | |||||
std::unique_ptr<OperationPass<FuncOp>> | |||||
mlir::createLinalgStrategyDecomposePass() { | |||||
return std::make_unique<LinalgStrategyDecomposePass>(); | |||||
} | |||||
/// Create a LinalgStrategyInterchangePass. | /// Create a LinalgStrategyInterchangePass. | ||||
std::unique_ptr<OperationPass<FuncOp>> | std::unique_ptr<OperationPass<FuncOp>> | ||||
mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, | mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange, | ||||
LinalgTransformationFilter filter) { | LinalgTransformationFilter filter) { | ||||
return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, | return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange, | ||||
filter); | filter); | ||||
} | } | ||||
Show All 28 Lines |