diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -320,6 +320,8 @@ I32EnumAttrCase<"PtrMemSize", 1, "ptr_mem_sz">, I32EnumAttrCase<"IdxMemSize", 2, "idx_mem_sz">, I32EnumAttrCase<"ValMemSize", 3, "val_mem_sz">, + I32EnumAttrCase<"DimOffset", 4, "dim_offset">, + I32EnumAttrCase<"DimStride", 5, "dim_stride">, ]> { let genSpecializedAttr = 0; let cppNamespace = SparseTensor_Dialect.cppNamespace; diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -345,20 +345,34 @@ } def SparseTensor_StorageSpecifierInitOp : SparseTensor_Op<"storage_specifier.init", [Pure]>, + Arguments<(ins Optional:$source)>, Results<(outs SparseTensorStorageSpecifier:$result)> { let summary = ""; let description = [{ Returns an initial storage specifier value. A storage specifier value holds the sizes for tensor dimensions, pointer arrays, index arrays, and the value array. + If this is a specifier for slices, it also holds the extra strides/offsets for tensor + dimensions. Example: ```mlir %0 = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#CSR> + %1 = sparse_tensor.storage_specifier.init with %src + : from !sparse_tensor.storage_specifier<#CSR> to + !sparse_tensor.storage_specifier<#CSR_SLICE> ``` }]; - let assemblyFormat = "attr-dict `:` qualified(type($result))"; + let builders = [ + OpBuilder<(ins "Type":$result), + [{ + build($_builder, $_state, result, Value()); + }]> + ]; + + let assemblyFormat = "attr-dict (`with` $source^)? `:` (`from` qualified(type($source))^ `to`)?" 
+ " qualified(type($result))"; } def SparseTensor_GetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.get", [Pure]>, diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -616,6 +616,12 @@ const auto enc = md.getType().getEncoding(); const Level lvlRank = enc.getLvlRank(); + // TODO: + // if (mdKind == StorageSpecifierKind::DimOffset || + // mdKind == StorageSpecifierKind::DimStride) + // if (!enc.isSlice()) + // return op->emitError("requested slice data on non-slice tensor"); + if (mdKind != StorageSpecifierKind::ValMemSize) { if (!lvl) return op->emitError("missing level argument"); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -21,12 +21,12 @@ // Helper methods. //===----------------------------------------------------------------------===// -static SmallVector getSpecifierFields(StorageSpecifierType tp) { +static SmallVector getSpecifierFields(StorageSpecifierType tp) { MLIRContext *ctx = tp.getContext(); auto enc = tp.getEncoding(); const Level lvlRank = enc.getLvlRank(); - SmallVector result; + SmallVector result; // TODO: how can we get the lowering type for index type in the later pipeline // to be consistent? LLVM::StructureType does not allow index fields. 
auto indexType = IntegerType::get(tp.getContext(), 64); @@ -35,6 +35,16 @@ getNumDataFieldsFromEncoding(enc)); result.push_back(dimSizes); result.push_back(memSizes); + + if (enc.isSlice()) { + // Extra fields are required for the slice information + auto dimOffset = LLVM::LLVMArrayType::get(ctx, indexType, lvlRank); + auto dimStride = LLVM::LLVMArrayType::get(ctx, indexType, lvlRank); + + result.push_back(dimOffset); + result.push_back(dimStride); + } + return result; } @@ -49,11 +59,13 @@ constexpr uint64_t kDimSizePosInSpecifier = 0; constexpr uint64_t kMemSizePosInSpecifier = 1; +constexpr uint64_t kDimOffsetPosInSpecifier = 2; +constexpr uint64_t kDimStridePosInSpecifier = 3; class SpecifierStructBuilder : public StructBuilder { private: Value extractField(OpBuilder &builder, Location loc, - ArrayRef indices) { + ArrayRef indices) const { return genCast(builder, loc, builder.create(loc, value, indices), builder.getIndexType()); @@ -72,34 +84,63 @@ } // Undef value for dimension sizes, all zero value for memory sizes. 
- static Value getInitValue(OpBuilder &builder, Location loc, Type structType); + static Value getInitValue(OpBuilder &builder, Location loc, Type structType, + Value source); + + Value dimOffset(OpBuilder &builder, Location loc, unsigned dim) const; + void setDimOffset(OpBuilder &builder, Location loc, unsigned dim, Value size); - Value dimSize(OpBuilder &builder, Location loc, unsigned dim); + Value dimSize(OpBuilder &builder, Location loc, unsigned dim) const; void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size); - Value memSize(OpBuilder &builder, Location loc, unsigned pos); + Value dimStride(OpBuilder &builder, Location loc, unsigned dim) const; + void setDimStride(OpBuilder &builder, Location loc, unsigned dim, Value size); + + Value memSize(OpBuilder &builder, Location loc, unsigned pos) const; void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size); + + Value memSizeArray(OpBuilder &builder, Location loc) const; + void setMemSizeArray(OpBuilder &builder, Location loc, Value array); }; Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc, - Type structType) { + Type structType, Value source) { Value metaData = builder.create(loc, structType); SpecifierStructBuilder md(metaData); - auto memSizeArrayType = structType.cast() - .getBody()[kMemSizePosInSpecifier] - .cast(); + if (!source) { + auto memSizeArrayType = structType.cast() + .getBody()[kMemSizePosInSpecifier] + .cast(); + + Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); + // Fill memSizes array with zero. 
+ for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) + md.setMemSize(builder, loc, i, zero); + } else { + // Copy the non-slice information (memory sizes array) from the source. + SpecifierStructBuilder sourceMd(source); + md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc)); + } + return md; +} - Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); - // Fill memSizes array with zero. - for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) - md.setMemSize(builder, loc, i, zero); +/// Builds IR extracting the pos-th offset from the descriptor. +Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc, + unsigned dim) const { + return extractField(builder, loc, + ArrayRef{kDimOffsetPosInSpecifier, dim}); +} - return md; +/// Builds IR inserting the pos-th offset into the descriptor. +void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc, + unsigned dim, Value size) { + insertField(builder, loc, ArrayRef{kDimOffsetPosInSpecifier, dim}, + size); } /// Builds IR inserting the pos-th size into the descriptor. Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc, - unsigned dim) { + unsigned dim) const { return extractField(builder, loc, ArrayRef{kDimSizePosInSpecifier, dim}); } @@ -112,9 +153,23 @@ size); } +/// Builds IR extracting the pos-th stride from the descriptor. +Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc, + unsigned dim) const { + return extractField(builder, loc, + ArrayRef{kDimStridePosInSpecifier, dim}); +} + +/// Builds IR inserting the pos-th stride into the descriptor. +void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc, + unsigned dim, Value size) { + insertField(builder, loc, ArrayRef{kDimStridePosInSpecifier, dim}, + size); +} + /// Builds IR extracting the pos-th memory size into the descriptor. 
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc, - unsigned pos) { + unsigned pos) const { return extractField(builder, loc, ArrayRef{kMemSizePosInSpecifier, pos}); } @@ -126,6 +181,20 @@ size); } +/// Builds IR extracting the memory size array from the descriptor. +Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder, + Location loc) const { + return builder.create(loc, value, + kMemSizePosInSpecifier); +} + +/// Builds IR inserting the memory size array into the descriptor. +void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc, + Value array) { + value = builder.create(loc, value, array, + kMemSizePosInSpecifier); +} + } // namespace //===----------------------------------------------------------------------===// @@ -151,22 +220,40 @@ matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SpecifierStructBuilder spec(adaptor.getSpecifier()); - Value v; - if (op.getSpecifierKind() == StorageSpecifierKind::DimSize) { - v = Base::onDimSize(rewriter, op, spec, - op.getDim().value().getZExtValue()); - } else { + switch (op.getSpecifierKind()) { + case StorageSpecifierKind::DimOffset: { + Value v = + Base::onDimOffset(rewriter, op, spec, (*op.getDim()).getZExtValue()); + rewriter.replaceOp(op, v); + return success(); + } + case StorageSpecifierKind::DimSize: { + Value v = + Base::onDimSize(rewriter, op, spec, (*op.getDim()).getZExtValue()); + rewriter.replaceOp(op, v); + return success(); + } + case StorageSpecifierKind::DimStride: { + Value v = + Base::onDimStride(rewriter, op, spec, (*op.getDim()).getZExtValue()); + rewriter.replaceOp(op, v); + return success(); + } + case StorageSpecifierKind::IdxMemSize: + case StorageSpecifierKind::PtrMemSize: + case StorageSpecifierKind::ValMemSize: { auto enc = op.getSpecifier().getType().getEncoding(); StorageLayout layout(enc); std::optional dim; if (op.getDim()) - dim = op.getDim().value().getZExtValue(); + dim = 
(*op.getDim()).getZExtValue(); unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), dim); - v = Base::onMemSize(rewriter, op, spec, idx); + Value v = Base::onMemSize(rewriter, op, spec, idx); + rewriter.replaceOp(op, v); + return success(); } - - rewriter.replaceOp(op, v); - return success(); + } + llvm_unreachable("unrecognized specifier kind"); } }; @@ -174,12 +261,25 @@ : public SpecifierGetterSetterOpConverter { using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; + + static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op, + SpecifierStructBuilder &spec, unsigned d) { + spec.setDimOffset(builder, op.getLoc(), d, op.getValue()); + return spec; + } + static Value onDimSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, unsigned d) { spec.setDimSize(builder, op.getLoc(), d, op.getValue()); return spec; } + static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op, + SpecifierStructBuilder &spec, unsigned d) { + spec.setDimStride(builder, op.getLoc(), d, op.getValue()); + return spec; + } + static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, unsigned i) { spec.setMemSize(builder, op.getLoc(), i, op.getValue()); @@ -191,12 +291,24 @@ : public SpecifierGetterSetterOpConverter { using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; + + static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op, + const SpecifierStructBuilder &spec, unsigned d) { + return spec.dimOffset(builder, op.getLoc(), d); + } + static Value onDimSize(OpBuilder &builder, GetStorageSpecifierOp op, - SpecifierStructBuilder &spec, unsigned d) { + const SpecifierStructBuilder &spec, unsigned d) { return spec.dimSize(builder, op.getLoc(), d); } + + static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op, + const SpecifierStructBuilder &spec, unsigned d) { + return spec.dimStride(builder, op.getLoc(), d); + } + static Value 
onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, - SpecifierStructBuilder &spec, unsigned i) { + const SpecifierStructBuilder &spec, unsigned i) { return spec.memSize(builder, op.getLoc(), i); } }; @@ -209,8 +321,9 @@ matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type llvmType = getTypeConverter()->convertType(op.getResult().getType()); - rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue( - rewriter, op.getLoc(), llvmType)); + rewriter.replaceOp( + op, SpecifierStructBuilder::getInitValue( + rewriter, op.getLoc(), llvmType, adaptor.getSource())); return success(); } }; diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -259,6 +259,17 @@ return %0 : index } +//// ----- +// +//#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> +// +//func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { +// // _e_xpected-error@+1 {{requested slice data on non-slice tensor}} +// %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0 +// : !sparse_tensor.storage_specifier<#SparseVector> to i64 +// return %0 : i64 +//} + // ----- #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -179,6 +179,25 @@ // ----- +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> +#SparseVector_Slice = #sparse_tensor.encoding<{ + dimLevelType = ["compressed"], + slice = [ (?, ?, ?) 
] +}> + +// CHECK-LABEL: func @sparse_metadata_init( +// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}> +// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.init with %[[A]] : +// CHECK: return %[[T]] : !sparse_tensor.storage_specifier<#{{.*}}> +func.func @sparse_metadata_init(%src : !sparse_tensor.storage_specifier<#SparseVector>) + -> !sparse_tensor.storage_specifier<#SparseVector_Slice> { + %0 = sparse_tensor.storage_specifier.init with %src : from !sparse_tensor.storage_specifier<#SparseVector> + to !sparse_tensor.storage_specifier<#SparseVector_Slice> + return %0 : !sparse_tensor.storage_specifier<#SparseVector_Slice> +} + +// ----- + #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> // CHECK-LABEL: func @sparse_get_md( @@ -191,6 +210,41 @@ return %0 : index } +// ----- + +#SparseVector_Slice = #sparse_tensor.encoding<{ + dimLevelType = ["compressed"], + slice = [ (?, ?, ?) ] +}> + +// CHECK-LABEL: func @sparse_get_md( +// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}> +// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_offset at 0 +// CHECK: return %[[T]] : index +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector_Slice>) -> index { + %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0 + : !sparse_tensor.storage_specifier<#SparseVector_Slice> + return %0 : index +} + +// ----- + +#SparseVector = #sparse_tensor.encoding<{ + dimLevelType = ["compressed"], + slice = [ (?, ?, ?) 
] +}> + +// CHECK-LABEL: func @sparse_get_md( +// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}> +// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_stride at 0 +// CHECK: return %[[T]] : index +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index { + %0 = sparse_tensor.storage_specifier.get %arg0 dim_stride at 0 + : !sparse_tensor.storage_specifier<#SparseVector> + return %0 : index +} + + // ----- #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>