diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -75,6 +75,12 @@
   PassOptions::Option<bool> enableVLAVectorization{
       *this, "enable-vla-vectorization",
       desc("Enable vector length agnostic vectorization"), init(false)};
+  PassOptions::Option<bool> enableRuntimeLibrary{
+      *this, "enable-runtime-library",
+      desc("Enable runtime library for manipulating sparse tensors"),
+      // TODO: Disable runtime library by default after feature complete.
+      init(true)};
+
   PassOptions::Option<bool> testBufferizationAnalysisOnly{
       *this, "test-bufferization-analysis-only",
       desc("Run only the inplacability analysis"), init(false)};
@@ -82,7 +88,8 @@
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
     return SparsificationOptions(parallelization, vectorization, vectorLength,
-                                 enableSIMDIndex32, enableVLAVectorization);
+                                 enableSIMDIndex32, enableVLAVectorization,
+                                 enableRuntimeLibrary);
   }

   // These options must be kept in sync with `SparseTensorConversionBase`.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -63,18 +63,22 @@
 struct SparsificationOptions {
   SparsificationOptions(SparseParallelizationStrategy p,
                         SparseVectorizationStrategy v, unsigned vl, bool e,
-                        bool vla)
       : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
-        enableSIMDIndex32(e), enableVLAVectorization(vla) {}
+                        bool vla, bool rt)
+      : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
+        enableSIMDIndex32(e), enableVLAVectorization(vla),
+        enableRuntimeLibrary(rt) {}
   SparsificationOptions()
       : SparsificationOptions(SparseParallelizationStrategy::kNone,
-                              SparseVectorizationStrategy::kNone, 1u, false,
-                              false) {}
+                              SparseVectorizationStrategy::kNone, 1u,
+                              /*enable SIMD Index32=*/false,
+                              /*enable VLA Vectorization=*/false,
+                              /*enable runtime library=*/true) {}
   SparseParallelizationStrategy parallelizationStrategy;
   SparseVectorizationStrategy vectorizationStrategy;
   unsigned vectorLength;
   bool enableSIMDIndex32;
   bool enableVLAVectorization;
+  bool enableRuntimeLibrary;
 };

 /// Sets up sparsification rewriting rules with the given options.
@@ -159,7 +163,7 @@
 // Other rewriting rules and passes.
 //===----------------------------------------------------------------------===//

-void populateSparseTensorRewriting(RewritePatternSet &patterns);
+void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT);

 std::unique_ptr<Pass> createDenseBufferizationPass(
     const bufferization::OneShotBufferizationOptions &options);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -97,7 +97,9 @@
     Option<"enableSIMDIndex32", "enable-simd-index32", "bool", "false",
            "Enable i32 indexing into vectors (for efficiency)">,
     Option<"enableVLAVectorization", "enable-vla-vectorization", "bool",
-           "false", "Enable vector length agnostic vectorization">
+           "false", "Enable vector length agnostic vectorization">,
+    Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
+           "true", "Enable runtime library for manipulating sparse tensors">
   ];
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"

 using namespace mlir;
 using namespace mlir::sparse_tensor;
@@ -58,8 +59,12 @@
   if (options.testBufferizationAnalysisOnly)
     return;
   pm.addPass(createSparsificationPass(options.sparsificationOptions()));
-  pm.addPass(createSparseTensorConversionPass(
-      options.sparseTensorConversionOptions()));
+  if (options.enableRuntimeLibrary)
+    pm.addPass(createSparseTensorConversionPass(
+        options.sparseTensorConversionOptions()));
+  else
+    pm.addPass(createSparseTensorCodegenPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addPass(createDenseBufferizationPass(
       getBufferizationOptions(/*analysisOnly=*/false)));
   pm.addNestedPass<func::FuncOp>(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -47,17 +47,19 @@
     vectorLength = options.vectorLength;
     enableSIMDIndex32 = options.enableSIMDIndex32;
     enableVLAVectorization = options.enableVLAVectorization;
+    enableRuntimeLibrary = options.enableRuntimeLibrary;
   }

   void runOnOperation() override {
     auto *ctx = &getContext();
-    // Apply pre-rewriting.
     RewritePatternSet prePatterns(ctx);
-    populateSparseTensorRewriting(prePatterns);
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
     // Translate strategy flags to strategy options.
     SparsificationOptions options(parallelization, vectorization, vectorLength,
-                                  enableSIMDIndex32, enableVLAVectorization);
+                                  enableSIMDIndex32, enableVLAVectorization,
+                                  enableRuntimeLibrary);
+    // Apply pre-rewriting.
+    populateSparseTensorRewriting(prePatterns, options.enableRuntimeLibrary);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
     // Apply sparsification and vector cleanup rewriting.
     RewritePatternSet patterns(ctx);
     populateSparsificationPatterns(patterns, options);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -300,8 +300,10 @@
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//

-void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
+void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
+                                         bool /*enableRT*/) {
   patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
                ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+  // TODO: If RT not enabled, rewrite concatenate ops, etc here.
 }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_dim.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_dim.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_dim.mlir
@@ -0,0 +1,46 @@
+// Test with/without runtime library, the result should always be identical.
+
+// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=true | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#DCSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed"]
+}>
+
+module {
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c3 = arith.constant 3 : index
+    %t1 = bufferization.alloc_tensor() : tensor<4x5xf64, #DCSR>
+    %t2 = bufferization.alloc_tensor(%c2, %c3) : tensor<?x?xf64, #DCSR>
+
+    %d1_0 = tensor.dim %t1, %c0 : tensor<4x5xf64, #DCSR>
+    %d2_0 = tensor.dim %t2, %c0 : tensor<?x?xf64, #DCSR>
+    %d1_1 = tensor.dim %t1, %c1 : tensor<4x5xf64, #DCSR>
+    %d2_1 = tensor.dim %t2, %c1 : tensor<?x?xf64, #DCSR>
+
+    // CHECK: 4
+    vector.print %d1_0 : index
+    // CHECK-NEXT: 2
+    vector.print %d2_0 : index
+    // CHECK-NEXT: 5
+    vector.print %d1_1 : index
+    // CHECK-NEXT: 3
+    vector.print %d2_1 : index
+    return
+  }
+}
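For reference, below is a minimal sketch (not part of the patch) of how the new flag can be driven from C++ when assembling a pass pipeline by hand, mirroring the branch added to SparseTensorPipelines.cpp above. It uses only declarations that appear in this diff, except for the parameterless createSparseTensorConversionPass() overload, which is assumed; the helper name buildSparsePasses is hypothetical.

// Minimal sketch, assuming the SparsificationOptions constructor and pass
// creation functions from SparseTensor/Transforms/Passes.h shown above.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Mirrors the new pipeline branch: record the flag in SparsificationOptions,
// then lower either through the runtime library (conversion) or directly
// through sparse tensor codegen.
static void buildSparsePasses(OpPassManager &pm, bool enableRuntimeLibrary) {
  SparsificationOptions options(SparseParallelizationStrategy::kNone,
                                SparseVectorizationStrategy::kNone,
                                /*vl=*/1u,
                                /*enableSIMDIndex32=*/false,
                                /*enableVLAVectorization=*/false,
                                /*rt=*/enableRuntimeLibrary);
  pm.addPass(createSparsificationPass(options));
  if (enableRuntimeLibrary)
    // Assumes the parameterless overload declared in Transforms/Passes.h.
    pm.addPass(createSparseTensorConversionPass());
  else
    pm.addPass(createSparseTensorCodegenPass());
}

As the TODO in Pipelines/Passes.h notes, the runtime-library path stays the default (init(true)) until the codegen path is feature complete, which is why both RUN lines in the new integration test must produce identical output.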