[mlir][Linalg] Retire LinalgStrategyLowerVectorsPass

Remove LinalgStrategyLowerVectorsPass together with its factory
createLinalgStrategyLowerVectorsPass, its TableGen definition
(linalg-strategy-lower-vectors-pass), the LinalgVectorLoweringOptions
struct, and the CodegenStrategy::vectorLowering / VectorLowering hooks.

TestVectorTransposeLowering no longer builds a dynamic pipeline around the
retired pass. It now populates the vector.transpose lowering patterns
(vector::populateVectorTransposeLoweringPatterns) directly, plus the
AVX2-specialized transpose patterns when lowerToAvx2 is set. It then runs
applyPatternsAndFoldGreedily on the function.

NOTE(review): behavior differences of the test pass vs. the retired pipeline:
  * The retired pass always added
    vector::populateVectorToVectorCanonicalizationPatterns. Because
    LinalgVectorLoweringOptions::transferLowering defaulted to true, it also
    added vector::populateVectorTransferLoweringPatterns(patterns, 1). The new
    test pass adds neither. Confirm the transpose-lowering tests do not rely
    on them.
  * The retired pass discarded the result of applyPatternsAndFoldGreedily.
    The test pass now calls signalPassFailure when that call fails.
NOTE(review): the context comment "/// Configurable pass to lower vector
operations." kept above LinalgStrategyRemoveMarkersPass now describes a pass
that no longer exists. Consider rewording it in a follow-up.
NOTE(review): this patch was recovered from a flattened copy that had lost
its newlines and its `<...>` template-argument lists.
  * The line layout was rebuilt so that every hunk matches its @@ counts.
  * Template arguments were restored from context. Two are reconstructions:
    the patterns.add<...> list in the contraction-lowering branch, and the
    class named in the context line `return std::make_unique<...>(filter);`.
    Verify both against the tree before applying.

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -103,14 +103,6 @@
     const linalg::LinalgTransformationFilter &filter =
         linalg::LinalgTransformationFilter());
 
-/// Create a LinalgStrategyLowerVectorsPass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createLinalgStrategyLowerVectorsPass(
-    linalg::LinalgVectorLoweringOptions opt =
-        linalg::LinalgVectorLoweringOptions(),
-    const linalg::LinalgTransformationFilter &filter =
-        linalg::LinalgTransformationFilter());
-
 /// Create a LinalgStrategyRemoveMarkersPass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 createLinalgStrategyRemoveMarkersPass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -212,17 +212,6 @@
   ];
 }
 
-def LinalgStrategyLowerVectorsPass
-    : Pass<"linalg-strategy-lower-vectors-pass", "func::FuncOp"> {
-  let summary = "Configurable pass to lower vector operations.";
-  let constructor = "mlir::createLinalgStrategyLowerVectorsPass()";
-  let dependentDialects = ["linalg::LinalgDialect"];
-  let options = [
-    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
-      "Which func op is the anchor to latch on.">,
-  ];
-}
-
 def LinalgStrategyRemoveMarkersPass
     : Pass<"linalg-strategy-remove-markers-pass", "func::FuncOp"> {
   let summary = "Cleanup pass that drops markers.";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -92,22 +92,6 @@
   }
 };
 
-/// Represent one application of createLinalgStrategyLowerVectorsPass.
-struct VectorLowering : public Transformation {
-  explicit VectorLowering(
-      linalg::LinalgVectorLoweringOptions options,
-      LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(std::move(f)), options(options) {}
-
-  void addToPassPipeline(OpPassManager &pm,
-                         LinalgTransformationFilter m) const override {
-    pm.addPass(createLinalgStrategyLowerVectorsPass(options, m));
-  }
-
-private:
-  linalg::LinalgVectorLoweringOptions options;
-};
-
 /// Codegen strategy controls how a Linalg op is progressively lowered.
 struct CodegenStrategy {
   /// Append a pattern to tile the Op `opName` and fuse its producers with
@@ -169,12 +153,6 @@
   decomposeIf(bool b, LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? decompose(std::move(f)) : *this;
   }
-  /// Append a pattern to lower all vector operations.
-  CodegenStrategy &vectorLowering(LinalgVectorLoweringOptions options) {
-    transformationSequence.emplace_back(
-        std::make_unique<VectorLowering>(options));
-    return *this;
-  }
   /// Configure the post staged-patterns global enabling passes options.
   CodegenStrategy &
   setVectorTransferToSCFOptions(LinalgEnablingOptions options) {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -927,96 +927,6 @@
   }
 };
 
-/// Vector lowering options control how ops are lowered down to 1-D and scf.for
-/// form.
-struct LinalgVectorLoweringOptions {
-  /// Enable lowering of vector.contract.
-  /// In a progressive lowering of vectors, this would be the 1st step.
-  bool contractionLowering = false;
-  LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
-    contractionLowering = val;
-    return *this;
-  }
-  /// Enable lowering of vector.multi_reduce.
-  /// In a progressive lowering of vectors, this would be the 2nd step.
-  bool multiReductionLowering = false;
-  LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
-    multiReductionLowering = val;
-    return *this;
-  }
-  /// Trigger full / partial vector.transfer splits.
-  /// In a progressive lowering of vectors, this would be the 3rd step.
-  bool transferPartialRewrite = false;
-  LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
-    transferPartialRewrite = val;
-    return *this;
-  }
-  /// Enable lowering of vector.transfer to scf.
-  /// In a progressive lowering of vectors, this would be the 4th step.
-  bool transferToSCFConversion = false;
-  LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
-    transferToSCFConversion = val;
-    return *this;
-  }
-  /// Maximal transfer rank under which we do not lower further.
-  int64_t maxTransferRank = 1;
-  LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
-    maxTransferRank = val;
-    return *this;
-  }
-  /// Vector lowering operations may result in surprising behavior when
-  /// composing multiple codegen strategies and must be enabled explicitly.
-  /// In a progressive lowering of vectors, this would be the 5th step.
-  bool transferLowering = true;
-  LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) {
-    transferLowering = val;
-    return *this;
-  }
-  /// Enable lowering of vector.shape_cast to insert/extract.
-  /// In a progressive lowering of vectors, this would be the 6th step.
-  bool shapeCastLowering = true;
-  LinalgVectorLoweringOptions &enableShapeCastLowering(bool val = true) {
-    shapeCastLowering = val;
-    return *this;
-  }
-  /// Enable lowering of vector.transpose.
-  /// In a progressive lowering of vectors, this would be the 7th step.
-  bool transposeLowering = false;
-  LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
-    transposeLowering = val;
-    return *this;
-  }
-  /// Enable AVX2-specific lowerings.
-  bool avx2Lowering = false;
-  LinalgVectorLoweringOptions &enableAVX2Lowering(bool val = true) {
-    avx2Lowering = val;
-    return *this;
-  }
-
-  /// Configure the post staged-patterns late vector.transfer to scf
-  /// conversion.
-  VectorTransferToSCFOptions vectorTransferToSCFOptions;
-  LinalgVectorLoweringOptions &
-  setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
-    vectorTransferToSCFOptions = options;
-    return *this;
-  }
-  /// Configure late vector transformations.
-  vector::VectorTransformsOptions vectorTransformOptions;
-  LinalgVectorLoweringOptions &
-  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
-    vectorTransformOptions = options;
-    return *this;
-  }
-  /// Configure specialized vector lowerings.
-  x86vector::avx2::LoweringOptions avx2LoweringOptions;
-  LinalgVectorLoweringOptions &
-  setAVX2LoweringOptions(x86vector::avx2::LoweringOptions options) {
-    avx2LoweringOptions = options;
-    return *this;
-  }
-};
-
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -180,71 +180,6 @@
   LinalgTransformationFilter filter;
 };
 
-/// Configurable pass to lower vector operations.
-struct LinalgStrategyLowerVectorsPass
-    : public impl::LinalgStrategyLowerVectorsPassBase<
-          LinalgStrategyLowerVectorsPass> {
-
-  LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
-                                 LinalgTransformationFilter filt)
-      : options(opt), filter(std::move(filt)) {}
-
-  void runOnOperation() override {
-    auto funcOp = getOperation();
-    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
-      return;
-
-    MLIRContext *context = funcOp.getContext();
-    RewritePatternSet patterns(context);
-    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
-    // In a progressive lowering of vectors, this would be the 1st step.
-    if (options.contractionLowering) {
-      patterns.add<ContractionOpToOuterProductOpLowering,
-                   ContractionOpToMatmulOpLowering, ContractionOpLowering>(
-          options.vectorTransformOptions, context);
-      vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
-    }
-    // In a progressive lowering of vectors, this would be the 2nd step.
-    if (options.multiReductionLowering) {
-      vector::populateVectorMultiReductionLoweringPatterns(
-          patterns,
-          options.vectorTransformOptions.vectorMultiReductionLowering);
-    }
-    // In a progressive lowering of vectors, this would be the 3rd step.
-    if (options.transferPartialRewrite) {
-      patterns.add<vector::VectorTransferFullPartialRewriter>(
-          context, options.vectorTransformOptions);
-    }
-    // In a progressive lowering of vectors, this would be the 4th step.
-    if (options.transferLowering) {
-      vector::populateVectorTransferLoweringPatterns(patterns,
-                                                     options.maxTransferRank);
-    }
-    // In a progressive lowering of vectors, this would be the 5th step.
-    if (options.transferToSCFConversion) {
-      populateVectorToSCFConversionPatterns(
-          patterns, options.vectorTransferToSCFOptions.setTargetRank(
-                        options.maxTransferRank));
-    }
-    // In a progressive lowering of vectors, this would be the 6th step.
-    if (options.shapeCastLowering) {
-      vector::populateVectorShapeCastLoweringPatterns(patterns);
-    }
-    // In a progressive lowering of vectors, this would be the 7th step.
-    if (options.transposeLowering) {
-      vector::populateVectorTransposeLoweringPatterns(
-          patterns, options.vectorTransformOptions);
-      if (options.avx2Lowering)
-        x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
-            patterns, options.avx2LoweringOptions, /*benefit=*/10);
-    }
-    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-  }
-
-  LinalgVectorLoweringOptions options;
-  LinalgTransformationFilter filter;
-};
-
 /// Configurable pass to lower vector operations.
 struct LinalgStrategyRemoveMarkersPass
     : public impl::LinalgStrategyRemoveMarkersPassBase<
@@ -294,13 +229,6 @@
   return std::make_unique<LinalgStrategyDecomposePass>(filter);
 }
 
-/// Create a LinalgStrategyLowerVectorsPass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgStrategyLowerVectorsPass(
-    LinalgVectorLoweringOptions opt, const LinalgTransformationFilter &filter) {
-  return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
-}
-
 /// Create a LinalgStrategyRemoveMarkersPass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::createLinalgStrategyRemoveMarkersPass() {
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -235,39 +235,40 @@
   }
 
   void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
+    func::FuncOp funcOp = getOperation();
+    MLIRContext *context = funcOp.getContext();
+    RewritePatternSet patterns(context);
 
-    // Test on one pattern in isolation.
-    // Explicitly disable shape_cast lowering.
-    LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
-                                              .enableVectorTransposeLowering()
-                                              .enableShapeCastLowering(false);
+    vector::VectorTransformsOptions vectorTransformOptions;
     if (lowerToEltwise) {
-      options = options.setVectorTransformsOptions(
-          VectorTransformsOptions().setVectorTransposeLowering(
-              VectorTransposeLowering::EltWise));
+      vectorTransformOptions =
+          vectorTransformOptions.setVectorTransposeLowering(
+              VectorTransposeLowering::EltWise);
     }
     if (lowerToFlatTranspose) {
-      options = options.setVectorTransformsOptions(
-          VectorTransformsOptions().setVectorTransposeLowering(
-              VectorTransposeLowering::Flat));
+      vectorTransformOptions =
+          vectorTransformOptions.setVectorTransposeLowering(
+              VectorTransposeLowering::Flat);
     }
     if (lowerToShuffleTranspose) {
-      options = options.setVectorTransformsOptions(
-          VectorTransformsOptions().setVectorTransposeLowering(
-              VectorTransposeLowering::Shuffle));
+      vectorTransformOptions =
+          vectorTransformOptions.setVectorTransposeLowering(
+              VectorTransposeLowering::Shuffle);
     }
+    vector::populateVectorTransposeLoweringPatterns(patterns,
+                                                    vectorTransformOptions);
+
     if (lowerToAvx2) {
-      options = options.enableAVX2Lowering().setAVX2LoweringOptions(
+      auto avx2LoweringOptions =
           x86vector::avx2::LoweringOptions().setTransposeOptions(
               x86vector::avx2::TransposeLoweringOptions()
                   .lower4x8xf32()
-                  .lower8x8xf32()));
+                  .lower8x8xf32());
+      x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+          patterns, avx2LoweringOptions, /*benefit=*/10);
     }
 
-    OpPassManager dynamicPM("func.func");
-    dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
-    if (failed(runPipeline(dynamicPM, getOperation())))
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
       return signalPassFailure();
   }
 };