diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -14,6 +14,12 @@ using namespace mlir; using namespace sparse_tensor; +namespace { + +//===----------------------------------------------------------------------===// +// Helper methods. +//===----------------------------------------------------------------------===// + static SmallVector getSpecifierFields(StorageSpecifierType tp) { MLIRContext *ctx = tp.getContext(); auto enc = tp.getEncoding(); @@ -34,10 +40,9 @@ getSpecifierFields(tp)); } -StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() { - addConversion([](Type type) { return type; }); - addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); }); -} +//===----------------------------------------------------------------------===// +// Specifier struct builder. +//===----------------------------------------------------------------------===// constexpr uint64_t kDimSizePosInSpecifier = 0; constexpr uint64_t kMemSizePosInSpecifier = 1; @@ -102,6 +107,21 @@ loc, value, size, ArrayRef({kMemSizePosInSpecifier, pos})); } +} // namespace + +//===----------------------------------------------------------------------===// +// The sparse storage specifier type converter (defined in Passes.h). +//===----------------------------------------------------------------------===// + +StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); }); +} + +//===----------------------------------------------------------------------===// +// Storage specifier conversion rules. 
+//===----------------------------------------------------------------------===// + template class SpecifierGetterSetterOpConverter : public OpConversionPattern { public: @@ -176,6 +196,10 @@ } }; +//===----------------------------------------------------------------------===// +// Public method for populating conversion rules. +//===----------------------------------------------------------------------===// + void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter, RewritePatternSet &patterns) { patterns.add &fields) { @@ -65,6 +75,10 @@ }); } +//===----------------------------------------------------------------------===// +// StorageLayout methods. +//===----------------------------------------------------------------------===// + unsigned StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional dim) const { unsigned fieldIdx = -1u; @@ -89,6 +103,10 @@ return getMemRefFieldIndex(toFieldKind(kind), dim); } +//===----------------------------------------------------------------------===// +// SparseTensorSpecifier methods. +//===----------------------------------------------------------------------===// + Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc, RankedTensorType rtp) { return builder.create( @@ -114,6 +132,10 @@ createIndexCast(builder, loc, v, getFieldType(kind, dim))); } +//===----------------------------------------------------------------------===// +// Public methods. +//===----------------------------------------------------------------------===// + constexpr uint64_t kDataFieldStartingIdx = 0; void sparse_tensor::foreachFieldInSparseTensor(