diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -134,12 +134,19 @@
     const SparseTensorConversionOptions &options = SparseTensorConversionOptions());
 
-std::unique_ptr<Pass> createDenseBufferizationPass(
-    const bufferization::OneShotBufferizationOptions &options);
 std::unique_ptr<Pass> createSparseTensorConversionPass();
 std::unique_ptr<Pass>
 createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
 
+//===----------------------------------------------------------------------===//
+// Other rewriting rules and passes.
+//===----------------------------------------------------------------------===//
+
+void populateSparseTensorRewriting(RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createDenseBufferizationPass(
+    const bufferization::OneShotBufferizationOptions &options);
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,7 +22,6 @@
   LinalgStrategyPasses.cpp
   NamedOpConversions.cpp
   Promotion.cpp
-  SparseTensorRewriting.cpp
   Split.cpp
   SplitReduction.cpp
   Tiling.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1717,12 +1717,8 @@
   // Add elementwise op fusion patterns.
   populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
   populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
-
-  // Add the sparse tensor rewriting patterns.
-  populateSparseTensorRewriting(patterns);
-
   // General canonicalization patterns.
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   GenericOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -52,11 +52,6 @@
     OpPassManager &pm, const SparseCompilerOptions &options) {
   // TODO(wrengr): ensure the original `pm` is for ModuleOp
   pm.addNestedPass<func::FuncOp>(createLinalgGeneralizationPass());
-  // TODO(springerm): Reactivate element-wise op fusion pass. This pass does not
-  // fit well with bufferization because it replaces unused "out" operands of
-  // LinalgOps with InitTensorOps. This would result in additional buffer
-  // allocations during bufferization.
-  // pm.addPass(createLinalgElementwiseOpFusionPass());
   pm.addPass(
       bufferization::createTensorCopyInsertionPass(getBufferizationOptions(
           /*analysisOnly=*/options.testBufferizationAnalysisOnly)));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@
   Sparsification.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
+  SparseTensorRewriting.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -49,13 +49,17 @@
   void runOnOperation() override {
     auto *ctx = &getContext();
-    RewritePatternSet patterns(ctx);
+    // Apply pre-rewriting.
+    RewritePatternSet prePatterns(ctx);
+    populateSparseTensorRewriting(prePatterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
     // Translate strategy flags to strategy options.
     SparsificationOptions options(
         sparseParallelizationStrategy(parallelization),
         sparseVectorizationStrategy(vectorization), vectorLength,
         enableSIMDIndex32, enableVLAVectorization);
-    // Apply rewriting.
+    // Apply sparsification and vector cleanup rewriting.
+    RewritePatternSet patterns(ctx);
     populateSparsificationPatterns(patterns, options);
     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
rename from mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
rename to mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -6,20 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements linalg dialect rewriting specific to sparse tensors.
-//
-// Sparsity should be mostly transparent to the linalg dialect optimizations
-// (i.e., the dense and sparse take the same path). However, in some cases,
-// optimizations only make sense in the context of sparse tensors. This file
-// implements such sparsity specific rewriting rules.
+// This file implements rewriting rules that are specific to sparse tensors.
 //
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
@@ -98,6 +94,7 @@
 //===---------------------------------------------------------------------===//
 
 namespace {
+
 /// Rewriting rule that converts two kernels:
 ///
 ///   T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
@@ -114,6 +111,7 @@
 /// a fusion may actually reduce the asymptotic complexity of the kernel,
 /// since intermediate results may be nullified.
 struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
+public:
   using OpRewritePattern<GenericOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(GenericOp op,
@@ -194,13 +192,55 @@
       mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
   }
 };
+
+/// Sparse rewriting rule for reshape operator.
+template <typename ReshapeOp>
+struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
+public:
+  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto encDst = getSparseTensorEncoding(op.getResult().getType());
+    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
+    // Since a pure dense expansion is very cheap (change of view), for
+    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
+    // conversion from the reshape operation itself.
+    // All other cases are handled elsewhere.
+    if (encDst && encSrc) {
+      return failure();
+    } else if (encSrc) {
+      RankedTensorType rtp =
+          op.getSrc().getType().template cast<RankedTensorType>();
+      auto denseTp =
+          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
+      op->setOperand(0, convert);
+      return success();
+    } else if (encDst) {
+      RankedTensorType rtp =
+          op.getResult().getType().template cast<RankedTensorType>();
+      auto denseTp =
+          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
+                                                op.getReassociation());
+      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
+      rewriter.replaceOp(op, convert);
+      return success();
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
-void mlir::linalg::populateSparseTensorRewriting(RewritePatternSet &patterns) {
-  auto *context = patterns.getContext();
-  patterns.add<FuseSparseMultiplyOverAdd>(context);
+void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
+  // TODO(springerm): enable FuseSparseMultiplyOverAdd
+  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
+               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1802,46 +1802,6 @@
   SparsificationOptions options;
 };
 
-/// Sparse rewriting rule for reshape operator.
-template <typename ReshapeOp>
-struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
-public:
-  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    auto encDst = getSparseTensorEncoding(op.getResult().getType());
-    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
-    // Since a pure dense expansion is very cheap (change of view), for
-    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
-    // conversion from the reshape operation itself.
-    // All other cases are handled elsewhere.
-    if (encDst && encSrc) {
-      return failure();
-    } else if (encSrc) {
-      RankedTensorType rtp =
-          op.getSrc().getType().template cast<RankedTensorType>();
-      auto denseTp =
-          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
-      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
-      op->setOperand(0, convert);
-      return success();
-    } else if (encDst) {
-      RankedTensorType rtp =
-          op.getResult().getType().template cast<RankedTensorType>();
-      auto denseTp =
-          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
-      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
-                                                op.getReassociation());
-      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
-      rewriter.replaceOp(op, convert);
-      return success();
-    }
-    return failure();
-  }
-};
-
 } // namespace
 
 /// Populates the given patterns list with rewriting rules required for
@@ -1849,6 +1809,4 @@
 void mlir::populateSparsificationPatterns(
     RewritePatternSet &patterns, const SparsificationOptions &options) {
   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
-  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
-               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
 }
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2115,6 +2115,7 @@
         ":SparseTensorDialect",
         ":SparseTensorPassIncGen",
        ":SparseTensorUtils",
+        ":Support",
         ":TensorDialect",
        ":Transforms",
        ":VectorDialect",