diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -368,6 +368,8 @@ I32EnumAttrCase<"PosMemSize", 1, "pos_mem_sz">, I32EnumAttrCase<"CrdMemSize", 2, "crd_mem_sz">, I32EnumAttrCase<"ValMemSize", 3, "val_mem_sz">, + I32EnumAttrCase<"DimOffset", 4, "dim_offset">, + I32EnumAttrCase<"DimStride", 5, "dim_stride">, ]> { let genSpecializedAttr = 0; let cppNamespace = SparseTensor_Dialect.cppNamespace; diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -358,21 +358,44 @@ } def SparseTensor_StorageSpecifierInitOp : SparseTensor_Op<"storage_specifier.init", [Pure]>, + Arguments<(ins Optional:$source)>, Results<(outs SparseTensorStorageSpecifier:$result)> { let summary = ""; let description = [{ Returns an initial storage specifier value. A storage specifier value holds the level-sizes, position arrays, coordinate arrays, and the value array. + If this is a specifier for slices, it also holds the extra strides/offsets + for each tensor dimension. + + Sparse tensor slice support is currently in an unstable state and is subject + to change in the future.
Example: ```mlir + #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ]}> + #CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (1, 4, 1), (1, 4, 2) ] + }> + %0 = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#CSR> + %1 = sparse_tensor.storage_specifier.init with %src + : from !sparse_tensor.storage_specifier<#CSR> to + !sparse_tensor.storage_specifier<#CSR_SLICE> ``` }]; - let assemblyFormat = "attr-dict `:` qualified(type($result))"; + let builders = [ + OpBuilder<(ins "Type":$result), + [{ + build($_builder, $_state, result, Value()); + }]> + ]; + + let assemblyFormat = "attr-dict (`with` $source^)? `:` (`from` qualified(type($source))^ `to`)?" + " qualified(type($result))"; } def SparseTensor_GetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.get", [Pure]>, diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -620,6 +620,12 @@ const auto enc = md.getType().getEncoding(); const Level lvlRank = enc.getLvlRank(); + // TODO: + // if (mdKind == StorageSpecifierKind::DimOffset || + // mdKind == StorageSpecifierKind::DimStride) + // if (!enc.isSlice()) + // return op->emitError("requested slice data on non-slice tensor"); + if (mdKind != StorageSpecifierKind::ValMemSize) { if (!lvl) return op->emitError("missing level argument"); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -21,12 +21,12 @@ // Helper methods.
//===----------------------------------------------------------------------===// -static SmallVector getSpecifierFields(StorageSpecifierType tp) { +static SmallVector getSpecifierFields(StorageSpecifierType tp) { MLIRContext *ctx = tp.getContext(); auto enc = tp.getEncoding(); const Level lvlRank = enc.getLvlRank(); - SmallVector result; + SmallVector result; // TODO: how can we get the lowering type for index type in the later pipeline // to be consistent? LLVM::StructureType does not allow index fields. auto sizeType = IntegerType::get(tp.getContext(), 64); @@ -35,6 +35,16 @@ getNumDataFieldsFromEncoding(enc)); result.push_back(lvlSizes); result.push_back(memSizes); + + if (enc.isSlice()) { + // Extra slice fields. NOTE(review): indexed by dim but sized by lvlRank. + auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank); + auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank); + + result.push_back(dimOffset); + result.push_back(dimStride); + } + return result; } @@ -49,11 +59,13 @@ constexpr uint64_t kLvlSizePosInSpecifier = 0; constexpr uint64_t kMemSizePosInSpecifier = 1; +constexpr uint64_t kDimOffsetPosInSpecifier = 2; +constexpr uint64_t kDimStridePosInSpecifier = 3; class SpecifierStructBuilder : public StructBuilder { private: Value extractField(OpBuilder &builder, Location loc, - ArrayRef indices) { + ArrayRef indices) const { return genCast(builder, loc, builder.create(loc, value, indices), builder.getIndexType()); @@ -71,36 +83,69 @@ assert(value); } - // Undef value for level-sizes, all zero values for memory-sizes. - static Value getInitValue(OpBuilder &builder, Location loc, Type structType); + // Undef level sizes; memory sizes zeroed, or copied from `source` if set.
+ static Value getInitValue(OpBuilder &builder, Location loc, Type structType, + Value source); - Value lvlSize(OpBuilder &builder, Location loc, Level lvl); + Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const; void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size); - Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx); + Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const; + void setDimOffset(OpBuilder &builder, Location loc, Dimension dim, + Value size); + + Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const; + void setDimStride(OpBuilder &builder, Location loc, Dimension dim, + Value size); + + Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const; void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx, Value size); + + Value memSizeArray(OpBuilder &builder, Location loc) const; + void setMemSizeArray(OpBuilder &builder, Location loc, Value array); }; Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc, - Type structType) { + Type structType, Value source) { Value metaData = builder.create(loc, structType); SpecifierStructBuilder md(metaData); - auto memSizeArrayType = structType.cast() - .getBody()[kMemSizePosInSpecifier] - .cast(); + if (!source) { + auto memSizeArrayType = structType.cast() + .getBody()[kMemSizePosInSpecifier] + .cast(); + + Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); + // Fill memSizes array with zero. + for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) + md.setMemSize(builder, loc, i, zero); + } else { + // Copy only the memory-sizes array (non-slice information) from source. + SpecifierStructBuilder sourceMd(source); + md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc)); + } + return md; +} - Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); - // Fill memSizes array with zero.
- for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) - md.setMemSize(builder, loc, i, zero); +/// Builds IR extracting the `dim`-th slice offset from the descriptor. +Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc, + Dimension dim) const { + return extractField( + builder, loc, + ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}); +} - return md; +/// Builds IR inserting the `dim`-th slice offset into the descriptor. +void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc, + Dimension dim, Value size) { + insertField(builder, loc, + ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}, + size); } /// Builds IR extracting the `lvl`-th level-size from the descriptor. Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc, - Level lvl) { + Level lvl) const { // This static_cast makes the narrowing of `lvl` explicit, as required // by the braces notation for the ctor. return extractField( @@ -119,18 +164,52 @@ size); } -/// Builds IR extracting the `fidx`-th memory-size from the descriptor. +/// Builds IR extracting the `dim`-th slice stride from the descriptor. +Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc, + Dimension dim) const { + return extractField( + builder, loc, + ArrayRef{kDimStridePosInSpecifier, static_cast(dim)}); +} + +/// Builds IR inserting the `dim`-th slice stride into the descriptor. +void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc, + Dimension dim, Value size) { + insertField( + builder, loc, + ArrayRef{kDimStridePosInSpecifier, static_cast(dim)}, + size); +} + +/// Builds IR extracting the `fidx`-th memory-size from the descriptor.
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc, - FieldIndex fidx) { - return extractField(builder, loc, - ArrayRef{kMemSizePosInSpecifier, fidx}); + FieldIndex fidx) const { + return extractField( + builder, loc, + ArrayRef{kMemSizePosInSpecifier, static_cast(fidx)}); } /// Builds IR inserting the `fidx`-th memory-size into the descriptor. void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx, Value size) { - insertField(builder, loc, ArrayRef{kMemSizePosInSpecifier, fidx}, - size); + insertField( + builder, loc, + ArrayRef{kMemSizePosInSpecifier, static_cast(fidx)}, + size); +} + +/// Builds IR extracting the whole (uncast) memory-sizes array. +Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder, + Location loc) const { + return builder.create(loc, value, + kMemSizePosInSpecifier); +} + +/// Builds IR inserting a whole memory-sizes array into the descriptor. +void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc, + Value array) { + value = builder.create(loc, value, array, + kMemSizePosInSpecifier); } } // namespace @@ -158,20 +237,37 @@ matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SpecifierStructBuilder spec(adaptor.getSpecifier()); - Value v; - if (op.getSpecifierKind() == StorageSpecifierKind::LvlSize) { - assert(op.getLevel().has_value()); - v = Base::onLvlSize(rewriter, op, spec, op.getLevel().value()); - } else { + switch (op.getSpecifierKind()) { + case StorageSpecifierKind::LvlSize: { + Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel())); + rewriter.replaceOp(op, v); + return success(); + } + case StorageSpecifierKind::DimOffset: { + Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel())); + rewriter.replaceOp(op, v); + return success(); + } + case StorageSpecifierKind::DimStride: { + Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel())); + rewriter.replaceOp(op, v); +
return success(); + } + case StorageSpecifierKind::CrdMemSize: + case StorageSpecifierKind::PosMemSize: + case StorageSpecifierKind::ValMemSize: { auto enc = op.getSpecifier().getType().getEncoding(); StorageLayout layout(enc); - FieldIndex fidx = - layout.getMemRefFieldIndex(op.getSpecifierKind(), op.getLevel()); - v = Base::onMemSize(rewriter, op, spec, fidx); + // getLevel() already returns an optional level (possibly unset), so + // it is forwarded to the layout query as-is. + FieldIndex fidx = + layout.getMemRefFieldIndex(op.getSpecifierKind(), op.getLevel()); + Value v = Base::onMemSize(rewriter, op, spec, fidx); + rewriter.replaceOp(op, v); + return success(); } - - rewriter.replaceOp(op, v); - return success(); + } + llvm_unreachable("unrecognized specifer kind"); } }; @@ -179,12 +275,25 @@ : public SpecifierGetterSetterOpConverter { using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; + static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl) { spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue()); return spec; } + static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op, + SpecifierStructBuilder &spec, Dimension d) { + spec.setDimOffset(builder, op.getLoc(), d, op.getValue()); + return spec; + } + + static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op, + SpecifierStructBuilder &spec, Dimension d) { + spec.setDimStride(builder, op.getLoc(), d, op.getValue()); + return spec; + } + static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx) { spec.setMemSize(builder, op.getLoc(), fidx, op.getValue()); @@ -196,10 +305,22 @@ : public SpecifierGetterSetterOpConverter { using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; + static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl) { return spec.lvlSize(builder, op.getLoc(), lvl); } + + static Value onDimOffset(OpBuilder &builder,
GetStorageSpecifierOp op, + const SpecifierStructBuilder &spec, Dimension d) { + return spec.dimOffset(builder, op.getLoc(), d); + } + + static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op, + const SpecifierStructBuilder &spec, Dimension d) { + return spec.dimStride(builder, op.getLoc(), d); + } + static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx) { return spec.memSize(builder, op.getLoc(), fidx); @@ -214,8 +335,9 @@ matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type llvmType = getTypeConverter()->convertType(op.getResult().getType()); - rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue( - rewriter, op.getLoc(), llvmType)); + rewriter.replaceOp( + op, SpecifierStructBuilder::getInitValue( + rewriter, op.getLoc(), llvmType, adaptor.getSource())); return success(); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -73,6 +73,20 @@ // memref ; values // struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes // +// If the sparse tensor is a slice (produced by a `tensor.extract_slice` +// operation), then instead of allocating a new sparse tensor for it, it reuses +// the same set of MemRefs and attaches an additional set of slicing metadata +// holding the per-dimension slice offset and stride. +// +// A slice of a 2-dim matrix with #COO storage yields +// ;; Inherited from the original sparse tensor +// memref, ; positions-0, essentially [0,sz] +// memref ; AOS coordinates storage +// memref ; values +// struct<(array<2 x i64>, array<3 x i64>, ; lvl0, lvl1, 3xsizes +// ;; Extra slicing-metadata +// array<2 x i64>, array<2 x i64>)>) ; dim offset, dim stride.
+// //===----------------------------------------------------------------------===// enum class SparseTensorFieldKind : uint32_t { diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -259,6 +259,17 @@ return %0 : index } +//// ----- +// +//#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> +// +//func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index { +// // _e_xpected-error@+1 {{requested slice data on non-slice tensor}} +// %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0 +// : !sparse_tensor.storage_specifier<#SparseVector> +// return %0 : index +//} + // ----- #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -179,6 +179,25 @@ // ----- +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> +#SparseVector_Slice = #sparse_tensor.encoding<{ + dimLevelType = ["compressed"], + slice = [ (?, ?, ?)
] +}> + +// CHECK-LABEL: func @sparse_metadata_init( +// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}> +// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.init with %[[A]] : +// CHECK: return %[[T]] : !sparse_tensor.storage_specifier<#{{.*}}> +func.func @sparse_metadata_init(%src : !sparse_tensor.storage_specifier<#SparseVector>) + -> !sparse_tensor.storage_specifier<#SparseVector_Slice> { + %0 = sparse_tensor.storage_specifier.init with %src : from !sparse_tensor.storage_specifier<#SparseVector> + to !sparse_tensor.storage_specifier<#SparseVector_Slice> + return %0 : !sparse_tensor.storage_specifier<#SparseVector_Slice> +} + +// ----- + #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> // CHECK-LABEL: func @sparse_get_md( @@ -191,6 +210,41 @@ return %0 : index } +// ----- + +#SparseVector_Slice = #sparse_tensor.encoding<{ + dimLevelType = ["compressed"], + slice = [ (?, ?, ?) ] +}> + +// CHECK-LABEL: func @sparse_get_md( +// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}> +// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_offset at 0 +// CHECK: return %[[T]] : index +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector_Slice>) -> index { + %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0 + : !sparse_tensor.storage_specifier<#SparseVector_Slice> + return %0 : index +} + +// ----- + +#SparseVector_Slice = #sparse_tensor.encoding<{ + dimLevelType = ["compressed"], + slice = [ (?, ?, ?)
] +}> + +// CHECK-LABEL: func @sparse_get_md( +// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}> +// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_stride at 0 +// CHECK: return %[[T]] : index +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector_Slice>) -> index { + %0 = sparse_tensor.storage_specifier.get %arg0 dim_stride at 0 + : !sparse_tensor.storage_specifier<#SparseVector_Slice> + return %0 : index +} + + // ----- #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>