diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h @@ -75,6 +75,11 @@ PassOptions::Option<bool> enableVLAVectorization{ *this, "enable-vla-vectorization", desc("Enable vector length agnostic vectorization"), init(false)}; + PassOptions::Option<bool> enableRuntimeLibrary{ + *this, "enable-runtime-library", + // TODO: Disable runtime library by default after feature complete. + desc("Use runtime library for manipulating sparse tensors"), init(true)}; + PassOptions::Option<bool> testBufferizationAnalysisOnly{ *this, "test-bufferization-analysis-only", desc("Run only the inplacability analysis"), init(false)}; @@ -82,7 +87,8 @@ /// Projects out the options for `createSparsificationPass`. SparsificationOptions sparsificationOptions() const { return SparsificationOptions(parallelization, vectorization, vectorLength, - enableSIMDIndex32, enableVLAVectorization); + enableSIMDIndex32, enableVLAVectorization, + enableRuntimeLibrary); } // These options must be kept in sync with `SparseTensorConversionBase`. 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -63,18 +63,20 @@ struct SparsificationOptions { SparsificationOptions(SparseParallelizationStrategy p, SparseVectorizationStrategy v, unsigned vl, bool e, - bool vla) + bool vla, bool rt) : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl), - enableSIMDIndex32(e), enableVLAVectorization(vla) {} + enableSIMDIndex32(e), enableVLAVectorization(vla), + enableRuntimeLibrary(rt) {} SparsificationOptions() : SparsificationOptions(SparseParallelizationStrategy::kNone, SparseVectorizationStrategy::kNone, 1u, false, - false) {} + false, true /*use runtime library*/) {} SparseParallelizationStrategy parallelizationStrategy; SparseVectorizationStrategy vectorizationStrategy; unsigned vectorLength; bool enableSIMDIndex32; bool enableVLAVectorization; + bool enableRuntimeLibrary; }; /// Sets up sparsification rewriting rules with the given options. @@ -159,7 +161,7 @@ // Other rewriting rules and passes. 
//===----------------------------------------------------------------------===// -void populateSparseTensorRewriting(RewritePatternSet &patterns); +void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT); std::unique_ptr createDenseBufferizationPass( const bufferization::OneShotBufferizationOptions &options); diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -97,7 +97,9 @@ Option<"enableSIMDIndex32", "enable-simd-index32", "bool", "false", "Enable i32 indexing into vectors (for efficiency)">, Option<"enableVLAVectorization", "enable-vla-vectorization", "bool", - "false", "Enable vector length agnostic vectorization"> + "false", "Enable vector length agnostic vectorization">, + Option<"enableRuntimeLibrary", "enable-runtime-library", "bool", + "true", "Use runtime library for manipulating sparse tensors"> ]; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -47,17 +47,19 @@ vectorLength = options.vectorLength; enableSIMDIndex32 = options.enableSIMDIndex32; enableVLAVectorization = options.enableVLAVectorization; + enableRuntimeLibrary = options.enableRuntimeLibrary; } void runOnOperation() override { auto *ctx = &getContext(); // Apply pre-rewriting. RewritePatternSet prePatterns(ctx); - populateSparseTensorRewriting(prePatterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns)); // Translate strategy flags to strategy options. 
SparsificationOptions options(parallelization, vectorization, vectorLength, - enableSIMDIndex32, enableVLAVectorization); + enableSIMDIndex32, enableVLAVectorization, + enableRuntimeLibrary); + populateSparseTensorRewriting(prePatterns, options.enableRuntimeLibrary); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns)); // Apply sparsification and vector cleanup rewriting. RewritePatternSet patterns(ctx); populateSparsificationPatterns(patterns, options); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -300,8 +300,10 @@ // Methods that add patterns described in this file to a pattern list. //===---------------------------------------------------------------------===// -void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) { +void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns, + bool /*enableRT*/) { patterns.add, ReshapeRewriter>(patterns.getContext()); + // If the runtime library is not enabled, rewrite concatenate ops, etc. here. }