[mlir][linalg] Optionally vectorize padding ops in the codegen strategy.

What this patch does
  * Passes.td: adds a boolean pass option `vectorize-padding` (member
    `vectorizePadding`, default "false") next to the existing anchor options.
  * Passes.h / LinalgStrategyPasses.cpp: createLinalgStrategyVectorizePass and
    the LinalgStrategyVectorizePass constructor take an extra trailing
    `bool padVectorize` (defaulted to false in the declaration and in the
    constructor) and store it in the `vectorizePadding` option. When the
    option is set, runOnFunction additionally calls
    linalg::populatePadTensorOpVectorizationPatterns on the pattern set before
    applyPatternsAndFoldGreedily runs.
  * CodegenStrategy.h: the Vectorize transformation records the flag and
    forwards it to createLinalgStrategyVectorizePass; CodegenStrategy::vectorize
    and CodegenStrategy::vectorizeIf gain a trailing defaulted
    `bool vectorizePadding = false` that is passed through.
  All new parameters are trailing and defaulted, so existing callers compile
  unchanged.

Review fixes folded into this revision
  * vectorizeIf: the second `return *this;` followed an unconditional return
    and could never execute; it is removed.
  * runOnFunction: the new single-statement `if` drops its braces, following
    the LLVM Coding Standards.

NOTE(review): this text was reconstructed from a copy in which line breaks and
angle-bracket template argument lists had been lost. Hunk line counts were
re-derived and match the headers, but the restored template arguments (in
particular the two pattern names in the `vectorizationPatterns.add<...>`
context lines) and the exact leading whitespace of context lines are
assumptions; confirm against the tree before applying. Everything above the
first "diff --git" line is commentary and is skipped by patch tools.

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -125,7 +125,8 @@
     linalg::LinalgVectorizationOptions opt =
         linalg::LinalgVectorizationOptions(),
     linalg::LinalgTransformationFilter filter =
-        linalg::LinalgTransformationFilter());
+        linalg::LinalgTransformationFilter(),
+    bool padVectorize = false);
 
 /// Create a LinalgStrategyEnablePass.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyEnablePass(
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -321,6 +321,8 @@
       "Which func op is the anchor to latch on.">,
     Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
       "Which linalg op within the func is the anchor to latch on.">,
+    Option<"vectorizePadding", "vectorize-padding", "bool", "false",
+      "Enable vectorization of padding ops.">,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -112,21 +112,27 @@
 /// Represent one application of createLinalgStrategyVectorizePass.
 struct Vectorize : public Transformation {
   explicit Vectorize(linalg::LinalgVectorizationOptions options,
-                     LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(f), opName(), options(options) {}
+                     LinalgTransformationFilter::FilterFunction f = nullptr,
+                     bool padVectorize = false)
+      : Transformation(f), opName(), options(options),
+        vectorizePadding(padVectorize) {}
 
   Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
-            LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(f), opName(name), options(options) {}
+            LinalgTransformationFilter::FilterFunction f = nullptr,
+            bool padVectorize = false)
+      : Transformation(f), opName(name), options(options),
+        vectorizePadding(padVectorize) {}
 
   void addToPassPipeline(OpPassManager &pm,
                          LinalgTransformationFilter m) const override {
-    pm.addPass(createLinalgStrategyVectorizePass(opName, options, m));
+    pm.addPass(createLinalgStrategyVectorizePass(opName, options, m,
+                                                 vectorizePadding));
   }
 
 private:
   std::string opName;
   linalg::LinalgVectorizationOptions options;
+  bool vectorizePadding;
 };
 
 /// Represent one application of createLinalgStrategyLowerVectorsPass.
@@ -228,18 +234,19 @@
   /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
   CodegenStrategy &
   vectorize(StringRef opName,
-            LinalgTransformationFilter::FilterFunction f = nullptr) {
+            LinalgTransformationFilter::FilterFunction f = nullptr,
+            bool vectorizePadding = false) {
     assert(!opName.empty() && "expected an op name");
     transformationSequence.emplace_back(std::make_unique<Vectorize>(
-        opName, linalg::LinalgVectorizationOptions(), f));
+        opName, linalg::LinalgVectorizationOptions(), f, vectorizePadding));
     return *this;
   }
   /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
   /// operation.
   CodegenStrategy &
   vectorizeIf(bool b, StringRef opName,
-              LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? vectorize(opName, f) : *this;
-    return *this;
+              LinalgTransformationFilter::FilterFunction f = nullptr,
+              bool vectorizePadding = false) {
+    return b ? vectorize(opName, f, vectorizePadding) : *this;
   }
   /// Append a pattern to lower all vector operations.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -223,9 +223,11 @@
   LinalgStrategyVectorizePass() = default;
 
   LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
-                              LinalgTransformationFilter filt)
+                              LinalgTransformationFilter filt,
+                              bool padVectorize = false)
       : options(opt), filter(filt) {
     this->anchorOpName.setValue(opName.str());
+    this->vectorizePadding.setValue(padVectorize);
   }
 
   void runOnFunction() override {
@@ -247,6 +249,8 @@
     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
                               linalg::LinalgCopyVTWForwardingPattern>(
         funcOp.getContext(), /*benefit=*/2);
+    if (vectorizePadding)
+      linalg::populatePadTensorOpVectorizationPatterns(vectorizationPatterns);
     (void)applyPatternsAndFoldGreedily(funcOp,
                                        std::move(vectorizationPatterns));
   }
@@ -425,11 +429,11 @@
 }
 
 /// Create a LinalgStrategyVectorizePass.
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createLinalgStrategyVectorizePass(StringRef opName,
-                                        LinalgVectorizationOptions opt,
-                                        LinalgTransformationFilter filter) {
-  return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter);
+std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(
+    StringRef opName, LinalgVectorizationOptions opt,
+    LinalgTransformationFilter filter, bool padVectorize) {
+  return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
+                                                       padVectorize);
 }
 
 /// Create a LinalgStrategyEnablePass.