diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -161,6 +161,23 @@
 
 std::unique_ptr<Pass> createSparseTensorCodegenPass();
 
+//===----------------------------------------------------------------------===//
+// The SparseTensorStorageExpansion pass.
+//===----------------------------------------------------------------------===//
+
+/// Sparse tensor storage type converter: expands each tuple type into its
+/// element types and keeps every other type unchanged.
+class SparseTensorStorageTupleExpander : public TypeConverter {
+public:
+  SparseTensorStorageTupleExpander();
+};
+
+/// Sets up sparse tensor storage expansion rules.
+void populateSparseTensorStorageExpansionPatterns(TypeConverter &typeConverter,
+                                                  RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createSparseTensorStorageExpansionPass();
+
 //===----------------------------------------------------------------------===//
 // Other rewriting rules and passes.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -146,4 +146,32 @@
   ];
 }
 
+def SparseTensorStorageExpansion : Pass<"sparse-tensor-storage-expansion", "ModuleOp"> {
+  let summary = "Expand compounded sparse tensor storage into individual SSA values";
+  let description = [{
+    A pass that expands sparse tensor storage (aggregated by tuple) into
+    individual SSA values. It also lowers sparse tensor storage operations,
+    e.g., sparse_tensor.storage_get and sparse_tensor.storage_set.
+
+    Example of the conversion:
+
+    ```mlir
+    Before:
+      func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) ->
+                                    tuple<memref<?xf64>, memref<?xf64>, f64> {
+        return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+      }
+    After:
+      func.func @sparse_storage_set(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+                                    (memref<?xf64>, memref<?xf64>, f64) {
+        return %arg0, %arg1, %arg2 : memref<?xf64>, memref<?xf64>, f64
+      }
+    ```
+  }];
+  let constructor = "mlir::createSparseTensorStorageExpansionPass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
   SparseTensorRewriting.cpp
+  SparseTensorStorageExpansion.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -24,6 +24,7 @@
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
+#define GEN_PASS_DEF_SPARSETENSORSTORAGEEXPANSION
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -157,7 +158,10 @@
     SparseTensorTypeToBufferConverter converter;
     ConversionTarget target(*ctx);
     // Everything in the sparse dialect must go!
-    target.addIllegalDialect<SparseTensorDialect>();
+    // Except sparse_tensor.storage_get and sparse_tensor.storage_set.
+    target.addDynamicallyLegalDialect<SparseTensorDialect>([](Operation *op) {
+      return isa<StorageGetOp, StorageSetOp>(op);
+    });
     // All dynamic rules below accept new function, call, return.
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return converter.isSignatureLegal(op.getFunctionType());
@@ -185,6 +189,44 @@
   }
 };
 
+struct SparseTensorStorageExpansionPass
+    : public impl::SparseTensorStorageExpansionBase<
+          SparseTensorStorageExpansionPass> {
+
+  SparseTensorStorageExpansionPass() = default;
+  SparseTensorStorageExpansionPass(
+      const SparseTensorStorageExpansionPass &pass) = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    SparseTensorStorageTupleExpander converter;
+    ConversionTarget target(*ctx);
+    // Now, everything in the sparse dialect must go!
+    target.addIllegalDialect<SparseTensorDialect>();
+    // All dynamic rules below accept new function, call, return.
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return converter.isSignatureLegal(op.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
+      return converter.isSignatureLegal(op.getCalleeType());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
+      return converter.isLegal(op.getOperandTypes());
+    });
+    // Populate with rules and apply rewriting rules.
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    populateSparseTensorStorageExpansionPatterns(converter, patterns);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -255,3 +297,7 @@
 std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
   return std::make_unique<SparseTensorCodegenPass>();
 }
+
+std::unique_ptr<Pass> mlir::createSparseTensorStorageExpansionPass() {
+  return std::make_unique<SparseTensorStorageExpansionPass>();
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
@@ -0,0 +1,96 @@
+//===- SparseTensorStorageExpansion.cpp - Sparse tensor storage expansion ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The sparse tensor storage expansion pass expands the nested storage for
+// sparse tensors (using tuple) to flattened SSA values.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Expands a sparse tensor storage tuple type into its element types.
+static Optional<LogicalResult>
+convertSparseTensorStorageTuple(Type t, SmallVectorImpl<Type> &result) {
+  if (auto tuple = t.dyn_cast<TupleType>()) {
+    // Nested tuples are not expanded recursively; this is fine for the
+    // sparse compiler, which never generates them.
+    result.append(tuple.getTypes().begin(), tuple.getTypes().end());
+    return success();
+  }
+  return llvm::None;
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Sparse tensor storage conversion rule for returns.
+class SparseStorageReturnConverter
+    : public OpConversionPattern<func::ReturnOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value, 8> flattened;
+    for (Value operand : adaptor.getOperands()) {
+      // Block arguments have no defining op, so query it null-safely.
+      auto castOp = operand.getDefiningOp<UnrealizedConversionCastOp>();
+      if (castOp && castOp->getResultTypes()[0].isa<TupleType>())
+        // During the 1:N tuple expansion an unrealized_conversion_cast is
+        // inserted to rebuild the tuple from its element values. Forward the
+        // cast's operands (the already flattened values) instead of the
+        // tuple itself.
+        llvm::append_range(flattened, castOp.getOperands());
+      else
+        flattened.push_back(operand);
+    }
+    // Create a return with the flattened values extracted from the tuples.
+    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Sparse tensor storage expansion
+//===----------------------------------------------------------------------===//
+
+mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
+  addConversion([](Type type) { return type; });
+  addConversion(convertSparseTensorStorageTuple);
+}
+
+//===----------------------------------------------------------------------===//
+// Public method for populating conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Populates the given patterns list with conversion rules required
+/// to expand compounded sparse tensor tuples.
+void mlir::populateSparseTensorStorageExpansionPatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  patterns.add<SparseStorageReturnConverter>(typeConverter,
+                                             patterns.getContext());
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -sparse-tensor-storage-expansion | FileCheck %s
+
+// CHECK-LABEL:  func @sparse_storage_expand(
+// CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg2:.*]]: f64
+// CHECK:          return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]
+func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
+                                      -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}