diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h @@ -49,6 +49,17 @@ vectorLength, enableSIMDIndex32); } + // These options must be kept in sync with `SparseTensorConversionBase`. + PassOptions::Option sparseToSparse{ + *this, "s2s-strategy", + desc("Set the strategy for sparse-to-sparse conversion"), init(0)}; + + /// Projects out the options for `createSparseTensorConversionPass`. + SparseTensorConversionOptions sparseTensorConversionOptions() const { + return SparseTensorConversionOptions( + sparseToSparseConversionStrategy(sparseToSparse)); + } + // These options must be kept in sync with `ConvertVectorToLLVMBase`. // TODO(wrengr): does `indexOptimizations` differ from `enableSIMDIndex32`? PassOptions::Option reassociateFPReductions{ diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -8,6 +8,12 @@ // // This header file defines prototypes of all sparse tensor passes. // +// In general, this file takes the approach of keeping "mechanism" (the +// actual steps of applying a transformation) completely separate from +// "policy" (heuristics for when and where to apply transformations). +// The only exception is `SparseToSparseConversionStrategy`; see the +// discussion on that enum for the rationale. +// //===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_ @@ -21,6 +27,10 @@ // Forward. class TypeConverter; +//===----------------------------------------------------------------------===// +// The Sparsification pass.
+//===----------------------------------------------------------------------===// + /// Defines a parallelization strategy. Any independent loop is a candidate /// for parallelization. The loop is made parallel if (1) allowed by the /// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse @@ -51,7 +61,7 @@ /// Converts command-line vectorization flag to the strategy enum. SparseVectorizationStrategy sparseVectorizationStrategy(int32_t flag); -/// Sparsification options. +/// Options for the Sparsification pass. struct SparsificationOptions { SparsificationOptions(SparseParallelizationStrategy p, SparseVectorizationStrategy v, unsigned vl, bool e) @@ -71,14 +81,58 @@ RewritePatternSet &patterns, const SparsificationOptions &options = SparsificationOptions()); -/// Sets up sparse tensor conversion rules. -void populateSparseTensorConversionPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns); - std::unique_ptr createSparsificationPass(); std::unique_ptr createSparsificationPass(const SparsificationOptions &options); + +//===----------------------------------------------------------------------===// +// The SparseTensorConversion pass. +//===----------------------------------------------------------------------===// + +/// Defines a strategy for implementing sparse-to-sparse conversion. +/// `kAuto` leaves it up to the compiler to automatically determine +/// the method used. `kViaCOO` converts the source tensor to COO and +/// then converts the COO to the target format. `kDirect` converts +/// directly via the algorithm in <https://arxiv.org/abs/2001.02609>; +/// however, beware that there are many formats not supported by this +/// conversion method. +/// +/// The presence of the `kAuto` option violates our usual goal of keeping +/// policy completely separated from mechanism. The reason it exists is +/// because (at present) this strategy can only be specified on a per-file +/// basis.
To see why this is a problem, note that `kDirect` cannot +/// support certain conversions; so if there is no `kAuto` setting, +/// then whenever a file contains a single non-`kDirect`-able conversion +/// the user would be forced to use `kViaCOO` for all conversions in +/// that file! In the future, instead of using this enum as a `Pass` +/// option, we could instead move it to being an attribute on the +/// conversion op; at which point `kAuto` would no longer be necessary. +enum class SparseToSparseConversionStrategy { kAuto, kViaCOO, kDirect }; + +/// Converts command-line sparse2sparse flag to the strategy enum. +/// A flag of 1 selects `kViaCOO` and 2 selects `kDirect`; any other value +/// (including the default 0) selects `kAuto`. +SparseToSparseConversionStrategy sparseToSparseConversionStrategy(int32_t flag); + +/// Options for the SparseTensorConversion pass. +struct SparseTensorConversionOptions { + explicit SparseTensorConversionOptions(SparseToSparseConversionStrategy s2s) + : sparseToSparseStrategy(s2s) {} + SparseTensorConversionOptions() + : SparseTensorConversionOptions(SparseToSparseConversionStrategy::kAuto) { + } + SparseToSparseConversionStrategy sparseToSparseStrategy; +}; + +/// Sets up sparse tensor conversion rules. +void populateSparseTensorConversionPatterns( + TypeConverter &typeConverter, RewritePatternSet &patterns, + const SparseTensorConversionOptions &options = + SparseTensorConversionOptions()); + std::unique_ptr createSparseTensorConversionPass(); +std::unique_ptr +createSparseTensorConversionPass(const SparseTensorConversionOptions &options); //===----------------------------------------------------------------------===// // Registration.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -114,6 +114,10 @@ "sparse_tensor::SparseTensorDialect", "vector::VectorDialect", ]; + let options = [ + Option<"sparseToSparse", "s2s-strategy", "int32_t", "0", + "Set the strategy for sparse-to-sparse conversion">, + ]; } #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -33,7 +33,8 @@ pm.addNestedPass(createLinalgGeneralizationPass()); pm.addPass(createLinalgElementwiseOpFusionPass()); pm.addPass(createSparsificationPass(options.sparsificationOptions())); - pm.addPass(createSparseTensorConversionPass()); + pm.addPass(createSparseTensorConversionPass( + options.sparseTensorConversionOptions())); pm.addNestedPass(createLinalgBufferizePass()); pm.addNestedPass(vector::createVectorBufferizePass()); pm.addNestedPass(createConvertLinalgToLoopsPass()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -453,7 +453,19 @@ /// Sparse conversion rule for the convert operator. class SparseTensorConvertConverter : public OpConversionPattern { + /// Options controlling the lowering, e.g. the sparse-to-sparse strategy.
+ SparseTensorConversionOptions options; + +public: using OpConversionPattern::OpConversionPattern; + SparseTensorConvertConverter(MLIRContext *context, + SparseTensorConversionOptions o) + : OpConversionPattern(context), options(o) {} + SparseTensorConvertConverter(TypeConverter &typeConverter_, + MLIRContext *context, + SparseTensorConversionOptions o) + : OpConversionPattern(typeConverter_, context), options(o) {} + LogicalResult matchAndRewrite(ConvertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -825,14 +837,17 @@ /// Populates the given patterns list with conversion rules required for /// the sparsification of linear algebra operations. -void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns) { +void mlir::populateSparseTensorConversionPatterns( + TypeConverter &typeConverter, RewritePatternSet &patterns, + const SparseTensorConversionOptions &options) { patterns.add(typeConverter, patterns.getContext()); + SparseTensorInitConverter, SparseTensorReleaseConverter, + SparseTensorToPointersConverter, SparseTensorToIndicesConverter, + SparseTensorToValuesConverter, SparseTensorLoadConverter, + SparseTensorLexInsertConverter, SparseTensorExpandConverter, + SparseTensorCompressConverter, SparseTensorOutConverter>( + typeConverter, patterns.getContext()); + patterns.add(typeConverter, + patterns.getContext(), options); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -73,6 +73,13 @@ struct SparseTensorConversionPass : public SparseTensorConversionBase { + + SparseTensorConversionPass() = default; + SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default; + SparseTensorConversionPass(const SparseTensorConversionOptions
&options) { + sparseToSparse = static_cast(options.sparseToSparseStrategy); + } + void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); @@ -106,11 +113,14 @@ target .addLegalDialect(); + // Translate the integer `sparseToSparse` flag into strategy options. + SparseTensorConversionOptions options( + sparseToSparseConversionStrategy(sparseToSparse)); // Populate with rules and apply rewriting rules. populateFunctionOpInterfaceTypeConversionPattern(patterns, converter); populateCallOpTypeConversionPattern(patterns, converter); - populateSparseTensorConversionPatterns(converter, patterns); + populateSparseTensorConversionPatterns(converter, patterns, options); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); @@ -146,6 +156,18 @@ } } +SparseToSparseConversionStrategy +mlir::sparseToSparseConversionStrategy(int32_t flag) { + switch (flag) { + default: + return SparseToSparseConversionStrategy::kAuto; + case 1: + return SparseToSparseConversionStrategy::kViaCOO; + case 2: + return SparseToSparseConversionStrategy::kDirect; + } +} + std::unique_ptr mlir::createSparsificationPass() { return std::make_unique(); } @@ -158,3 +180,8 @@ std::unique_ptr mlir::createSparseTensorConversionPass() { return std::make_unique(); } + +std::unique_ptr mlir::createSparseTensorConversionPass( + const SparseTensorConversionOptions &options) { + return std::make_unique(options); +}