diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -11,8 +11,6 @@ // In general, this file takes the approach of keeping "mechanism" (the // actual steps of applying a transformation) completely separate from // "policy" (heuristics for when and where to apply transformations). -// The only exception is in `SparseToSparseConversionStrategy`; for which, -// see further discussion there. // //===----------------------------------------------------------------------===// @@ -21,15 +19,13 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace bufferization { struct OneShotBufferizationOptions; } // namespace bufferization -// Forward. -class TypeConverter; - //===----------------------------------------------------------------------===// // The Sparsification pass. //===----------------------------------------------------------------------===// @@ -95,6 +91,12 @@ // The SparseTensorConversion pass. //===----------------------------------------------------------------------===// +/// Sparse tensor type converter into an opaque pointer. +class SparseTensorTypeToPtrConverter : public TypeConverter { +public: + SparseTensorTypeToPtrConverter(); +}; + /// Defines a strategy for implementing sparse-to-sparse conversion. /// `kAuto` leaves it up to the compiler to automatically determine /// the method used. `kViaCOO` converts the source tensor to COO and @@ -138,6 +140,22 @@ std::unique_ptr createSparseTensorConversionPass(const SparseTensorConversionOptions &options); +//===----------------------------------------------------------------------===// +// The SparseTensorCodegen pass.
+//===----------------------------------------------------------------------===// + +/// Sparse tensor type converter into an actual buffer. +class SparseTensorTypeToBufferConverter : public TypeConverter { +public: + SparseTensorTypeToBufferConverter(); +}; + +/// Sets up sparse tensor codegen rules. +void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns); + +std::unique_ptr createSparseTensorCodegenPass(); + //===----------------------------------------------------------------------===// // Other rewriting rules and passes. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -79,14 +79,14 @@ def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> { let summary = "Apply conversion rules to sparse tensor primitives and types"; let description = [{ - A pass that converts sparse tensor primitives to calls into a runtime - support library. All sparse tensor types are converted into opaque - pointers to the underlying sparse storage schemes. + A pass that converts sparse tensor primitives into calls to a runtime + support library. Sparse tensor types are converted into opaque pointers + to the underlying sparse storage schemes. - Note that this is a current implementation choice to keep the conversion - relatively simple. In principle, these primitives could also be - converted to actual elaborate IR code that implements the primitives - on the selected sparse tensor storage schemes. + The use of opaque pointers together with the runtime support library keeps + the conversion relatively simple, but at the expense of IR opacity, + which obscures opportunities for subsequent optimization of the IR.
+ An alternative is provided by the SparseTensorCodegen pass. Example of the conversion: @@ -122,4 +122,28 @@ ]; } +def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> { + let summary = "Convert sparse tensor types and primitives to actual code"; + let description = [{ + A pass that converts sparse tensor types and primitives to actual + compiler-visible buffers and actual compiler IR that implements these + primitives on the selected sparse tensor storage schemes. + + This pass provides an alternative to the SparseTensorConversion pass, + eliminating the dependence on a runtime support library, and providing + many more opportunities for subsequent compiler optimization of the + generated code. + + Example of the conversion: + + ```mlir + TBD + ``` + }]; + let constructor = "mlir::createSparseTensorCodegenPass()"; + let dependentDialects = [ + "sparse_tensor::SparseTensorDialect", + ]; +} + #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ CodegenUtils.cpp DenseBufferizationPass.cpp Sparsification.cpp + SparseTensorCodegen.cpp SparseTensorConversion.cpp SparseTensorPasses.cpp SparseTensorRewriting.cpp diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -0,0 +1,82 @@ +//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A pass that converts sparse tensor types and primitives to actual compiler +// visible buffers and actual compiler IR that implements these primitives on +// the selected sparse tensor storage schemes. This pass provides an alternative +// to the SparseTensorConversion pass, eliminating the dependence on a runtime +// support library, and providing many more opportunities for subsequent +// compiler optimization of the generated code. +// +//===----------------------------------------------------------------------===// + +#include "CodegenUtils.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace mlir::sparse_tensor; + +//===----------------------------------------------------------------------===// +// Helper methods. +//===----------------------------------------------------------------------===// + +/// Maps each sparse tensor type to the appropriate buffer. +static Optional convertSparseTensorTypes(Type type) { + if (getSparseTensorEncoding(type) != nullptr) { + // TODO: this is just a dummy rule to get the ball rolling.... + RankedTensorType rTp = type.cast(); + return MemRefType::get({ShapedType::kDynamicSize}, rTp.getElementType()); + } + return llvm::None; +} + +namespace { + +//===----------------------------------------------------------------------===// +// Conversion rules. +//===----------------------------------------------------------------------===// + +/// Sparse conversion rule for returns.
+class SparseReturnConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Sparse tensor type conversion into an actual buffer. +//===----------------------------------------------------------------------===// + +mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { + addConversion([](Type type) { return type; }); + addConversion(convertSparseTensorTypes); +} + +//===----------------------------------------------------------------------===// +// Public method for populating conversion rules. +//===----------------------------------------------------------------------===// + +/// Populates the given patterns list with conversion rules required to +/// convert sparse tensor types and primitives into actual code. +void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add(typeConverter, patterns.getContext()); +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -6,11 +6,13 @@ // //===----------------------------------------------------------------------===// // -// Convert sparse tensor primitives to calls into a runtime support library. -// Note that this is a current implementation choice to keep the conversion -// simple.
In principle, these primitives could also be converted to actual -// elaborate IR code that implements the primitives on the selected sparse -// tensor storage schemes. +// A pass that converts sparse tensor primitives into calls to a runtime +// support library. Sparse tensor types are converted into opaque pointers +// to the underlying sparse storage schemes. The use of opaque pointers +// together with the runtime support library keeps the conversion relatively +// simple, but at the expense of IR opacity, which obscures opportunities +// for subsequent optimization of the IR. An alternative is provided by +// the SparseTensorCodegen pass. // //===----------------------------------------------------------------------===// @@ -48,6 +50,13 @@ return LLVM::LLVMPointerType::get(builder.getI8Type()); } +/// Maps each sparse tensor type to an opaque pointer. +static Optional convertSparseTensorTypes(Type type) { + if (getSparseTensorEncoding(type) != nullptr) + return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); + return llvm::None; +} + /// Returns a function reference (first hit also inserts into module). Sets /// the "_emit_c_interface" on the function declaration when requested, /// so that LLVM lowering generates a wrapper function that takes care @@ -1345,6 +1354,7 @@ return success(); } }; + /// Sparse conversion rule for the output operator. class SparseTensorOutConverter : public OpConversionPattern { public: @@ -1387,6 +1397,15 @@ } // namespace +//===----------------------------------------------------------------------===// +// Sparse tensor type conversion into an opaque pointer.
+//===----------------------------------------------------------------------===// + +mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() { + addConversion([](Type type) { return type; }); + addConversion(convertSparseTensorTypes); +} + //===----------------------------------------------------------------------===// // Public method for populating conversion rules. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -67,20 +67,6 @@ } }; -class SparseTensorTypeConverter : public TypeConverter { -public: - SparseTensorTypeConverter() { - addConversion([](Type type) { return type; }); - addConversion(convertSparseTensorTypes); - } - // Maps each sparse tensor type to an opaque pointer. - static Optional convertSparseTensorTypes(Type type) { - if (getSparseTensorEncoding(type) != nullptr) - return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); - return llvm::None; - } -}; - struct SparseTensorConversionPass : public SparseTensorConversionBase { @@ -93,7 +79,7 @@ void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - SparseTensorTypeConverter converter; + SparseTensorTypeToPtrConverter converter; ConversionTarget target(*ctx); // Everything in the sparse dialect must go!
target.addIllegalDialect(); @@ -158,8 +144,49 @@ } }; +struct SparseTensorCodegenPass + : public SparseTensorCodegenBase { + + SparseTensorCodegenPass() = default; + SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default; + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + SparseTensorTypeToBufferConverter converter; + ConversionTarget target(*ctx); + // Everything in the sparse dialect must go! + target.addIllegalDialect(); + // All dynamic rules below accept new function, call, and return + // operations as legal output of the rewriting. + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return converter.isSignatureLegal(op.getFunctionType()); + }); + target.addDynamicallyLegalOp([&](func::CallOp op) { + return converter.isSignatureLegal(op.getCalleeType()); + }); + target.addDynamicallyLegalOp([&](func::ReturnOp op) { + return converter.isLegal(op.getOperandTypes()); + }); + // Populate with rules and apply rewriting rules. + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + populateCallOpTypeConversionPattern(patterns, converter); + scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, + target); + populateSparseTensorCodegenPatterns(converter, patterns); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; + } // namespace +//===----------------------------------------------------------------------===// +// Strategy flag methods. +//===----------------------------------------------------------------------===// + SparseParallelizationStrategy mlir::sparseParallelizationStrategy(int32_t flag) { switch (flag) { @@ -199,6 +226,10 @@ } } +//===----------------------------------------------------------------------===// +// Pass creation methods.
+//===----------------------------------------------------------------------===// + std::unique_ptr mlir::createSparsificationPass() { return std::make_unique(); } @@ -216,3 +247,7 @@ const SparseTensorConversionOptions &options) { return std::make_unique(options); } + +std::unique_ptr mlir::createSparseTensorCodegenPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/codegen.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s + +#SparseVector = #sparse_tensor.encoding<{ + dimLevelType = ["compressed"] +}> + +// TODO: just a dumb memref rewriting to get the ball rolling.... + +// CHECK-LABEL: func @sparse_nop( +// CHECK-SAME: %[[A:.*]]: memref) -> memref { +// CHECK: return %[[A]] : memref +func.func @sparse_nop(%arg0: tensor) -> tensor { + return %arg0 : tensor +} diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -25,6 +25,13 @@ dimOrdering = affine_map<(i,j,k) -> (k,i,j)> }> +// CHECK-LABEL: func @sparse_nop( +// CHECK-SAME: %[[A:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK: return %[[A]] : !llvm.ptr +func.func @sparse_nop(%arg0: tensor) -> tensor { + return %arg0 : tensor +} + // CHECK-LABEL: func @sparse_dim1d( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) // CHECK: %[[C:.*]] = arith.constant 0 : index