Please use GitHub pull requests for new patches. Avoid migrating existing patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/test/lib/Transforms/TestConvVectorization.cpp
Show All 18 Lines | |||||
using namespace mlir; | using namespace mlir; | ||||
using namespace vector; | using namespace vector; | ||||
namespace { | namespace { | ||||
/// A pass converting MLIR Linalg ops into Vector ops. | /// A pass converting MLIR Linalg ops into Vector ops. | ||||
class TestConvVectorization | class TestConvVectorization | ||||
: public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> { | : public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> { | ||||
public: | |||||
TestConvVectorization() = default; | |||||
TestConvVectorization(const TestConvVectorization &) {} | |||||
explicit TestConvVectorization(ArrayRef<int64_t> tileSizesParam) { | |||||
tileSizes = tileSizesParam; | |||||
} | |||||
void runOnOperation() override; | void runOnOperation() override; | ||||
void getDependentDialects(DialectRegistry ®istry) const override { | void getDependentDialects(DialectRegistry ®istry) const override { | ||||
registry.insert<VectorDialect>(); | registry.insert<VectorDialect>(); | ||||
registry.insert<linalg::LinalgDialect>(); | registry.insert<linalg::LinalgDialect>(); | ||||
registry.insert<scf::SCFDialect>(); | registry.insert<scf::SCFDialect>(); | ||||
registry.insert<AffineDialect>(); | registry.insert<AffineDialect>(); | ||||
registry.insert<StandardOpsDialect>(); | registry.insert<StandardOpsDialect>(); | ||||
} | } | ||||
ListOption<int64_t> tileSizes{ | |||||
*this, "tile-sizes", llvm::cl::desc("Vectorization sizes."), | |||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; | |||||
}; | }; | ||||
} // namespace | } // namespace | ||||
void TestConvVectorization::runOnOperation() { | void TestConvVectorization::runOnOperation() { | ||||
MLIRContext *context = &getContext(); | MLIRContext *context = &getContext(); | ||||
ModuleOp module = getOperation(); | ModuleOp module = getOperation(); | ||||
ConversionTarget target(*context); | ConversionTarget target(*context); | ||||
target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect, | target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect, | ||||
VectorDialect>(); | VectorDialect>(); | ||||
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>(); | target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>(); | ||||
target.addLegalOp<linalg::FillOp, linalg::YieldOp>(); | target.addLegalOp<linalg::FillOp, linalg::YieldOp>(); | ||||
SmallVector<OwningRewritePatternList, 4> stage1Patterns; | SmallVector<OwningRewritePatternList, 4> stage1Patterns; | ||||
linalg::populateConvVectorizationPatterns(context, stage1Patterns); | linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes); | ||||
OwningRewritePatternList stage2Patterns = | OwningRewritePatternList stage2Patterns = | ||||
linalg::getLinalgTilingCanonicalizationPatterns(context); | linalg::getLinalgTilingCanonicalizationPatterns(context); | ||||
stage2Patterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(context); | stage2Patterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(context); | ||||
auto stage3Transforms = [](Operation *op) { | auto stage3Transforms = [](Operation *op) { | ||||
PassManager pm(op->getContext()); | PassManager pm(op->getContext()); | ||||
pm.addPass(createLoopInvariantCodeMotionPass()); | pm.addPass(createLoopInvariantCodeMotionPass()); | ||||
▲ Show 20 Lines • Show All 58 Lines • Show Last 20 Lines |