diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -119,10 +119,32 @@
 /// memory hierarchy.
 std::unique_ptr<OperationPass<func::FuncOp>> createPipelineDataTransferPass();
 
+/// Options for the supervectorizer. It should be kept in sync with the options
+/// defined for the pass.
+///
+/// The ArrayRef members are non-owning views: the storage they refer to must
+/// stay alive until createSuperVectorizePass(options) returns. The pass copies
+/// the values into its own options when it is constructed.
+struct SuperVectorizeOptions {
+  SuperVectorizeOptions(ArrayRef<int64_t> virtualVectorSize,
+                        bool vectorizeReductions = false,
+                        ArrayRef<int64_t> fastestVaryingPattern = {})
+      : virtualVectorSize(virtualVectorSize),
+        vectorizeReductions(vectorizeReductions),
+        fastestVaryingPattern(fastestVaryingPattern) {}
+
+  /// Virtual vector sizes; forwarded to the pass option `vectorSizes`.
+  ArrayRef<int64_t> virtualVectorSize;
+  /// Forwarded to the pass option `vectorizeReductions`.
+  bool vectorizeReductions{false};
+  /// Forwarded to the pass option `fastestVaryingPattern`.
+  ArrayRef<int64_t> fastestVaryingPattern;
+};
+
 /// Creates a pass to vectorize loops, operations and data types using a
 /// target-independent, n-D super-vector abstraction.
 std::unique_ptr<OperationPass<func::FuncOp>>
-createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize);
+createSuperVectorizePass(const SuperVectorizeOptions &options);
 
 /// Overload relying on pass options for initialization.
 std::unique_ptr<OperationPass<func::FuncOp>> createSuperVectorizePass();
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -617,14 +617,16 @@
 /// Command line arguments are preempted by non-empty pass arguments.
 struct Vectorize : public impl::AffineVectorizeBase<Vectorize> {
   Vectorize() = default;
-  Vectorize(ArrayRef<int64_t> virtualVectorSize);
+  explicit Vectorize(const SuperVectorizeOptions &options);
   void runOnOperation() override;
 };
 
 } // namespace
 
-Vectorize::Vectorize(ArrayRef<int64_t> virtualVectorSize) {
-  vectorSizes = virtualVectorSize;
+Vectorize::Vectorize(const SuperVectorizeOptions &options) {
+  vectorSizes = options.virtualVectorSize;
+  vectorizeReductions = options.vectorizeReductions;
+  fastestVaryingPattern = options.fastestVaryingPattern;
 }
 
 static void vectorizeLoopIfProfitable(Operation *loop, unsigned depthInPattern,
@@ -1717,14 +1719,6 @@
   LLVM_DEBUG(dbgs() << "\n");
 }
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize) {
-  return std::make_unique<Vectorize>(virtualVectorSize);
-}
-std::unique_ptr<OperationPass<func::FuncOp>> createSuperVectorizePass() {
-  return std::make_unique<Vectorize>();
-}
-
 /// Applies vectorization to the current function by searching over a bunch of
 /// predetermined patterns.
 void Vectorize::runOnOperation() {
@@ -1873,8 +1867,8 @@
 }
 
 std::unique_ptr<OperationPass<func::FuncOp>>
-createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize) {
-  return std::make_unique<Vectorize>(virtualVectorSize);
+createSuperVectorizePass(const SuperVectorizeOptions &options) {
+  return std::make_unique<Vectorize>(options);
 }
 std::unique_ptr<OperationPass<func::FuncOp>> createSuperVectorizePass() {
   return std::make_unique<Vectorize>();