diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -293,6 +293,8 @@ I32EnumAttrCase<"PtrMemSize", 1, "ptr_mem_sz">, I32EnumAttrCase<"IdxMemSize", 2, "idx_mem_sz">, I32EnumAttrCase<"ValMemSize", 3, "val_mem_sz">, + I32EnumAttrCase<"DimOffset", 4, "dim_offset">, + I32EnumAttrCase<"DimStride", 5, "dim_stride">, ]> { let genSpecializedAttr = 0; let cppNamespace = SparseTensor_Dialect.cppNamespace; @@ -312,13 +314,21 @@ def IsSparseTensorPred : CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">; +def IsSparseTensorSlicePred + : CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self) && " + " ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">; + // The following four follow the same idiom as `TensorOf`, `AnyTensor`, // `RankedTensorOf`, `AnyRankedTensor`. 
class SparseTensorOf allowedTypes> : TensorOf; +class SparseTensorSliceOf allowedTypes> + : TensorOf; + def AnySparseTensor : SparseTensorOf<[AnyType]>; +def AnySparseTensorSlice : SparseTensorSliceOf<[AnyType]>; class RankedSparseTensorOf allowedTypes> : RankedTensorOf; diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -205,21 +205,65 @@ let hasVerifier = 1; } +def SparseTensor_ToSliceOffsetOp : SparseTensor_Op<"slice.offset", [Pure]>, + Arguments<(ins AnySparseTensorSlice:$slice, IndexAttr:$dim)>, + Results<(outs Index:$offset)> { + let summary = "Extracts the offset of the sparse tensor slice at the given dimension"; + let description = [{ + Example: + + ```mlir + %1 = sparse_tensor.slice.offset %0 at 1 : tensor<64x64xf64, #Slice> + ``` + }]; + let assemblyFormat = "$slice `at` $dim attr-dict `:` type($slice)"; + let hasVerifier = 1; +} + +def SparseTensor_ToSliceStrideOp : SparseTensor_Op<"slice.stride", [Pure]>, + Arguments<(ins AnySparseTensorSlice:$slice, IndexAttr:$dim)>, + Results<(outs Index:$stride)> { + let summary = "Extracts the stride of the sparse tensor slice at the given dimension"; + let description = [{ + Example: + + ```mlir + %1 = sparse_tensor.slice.stride %0 at 1 : tensor<64x64xf64, #Slice> + ``` + }]; + let assemblyFormat = "$slice `at` $dim attr-dict `:` type($slice)"; + let hasVerifier = 1; +} + def SparseTensor_StorageSpecifierInitOp : SparseTensor_Op<"storage_specifier.init", [Pure]>, + Arguments<(ins Optional:$source)>, Results<(outs SparseTensorStorageSpecifier:$result)> { let summary = ""; let description = [{ Returns an initial storage specifier value. A storage specifier value holds the sizes for tensor dimensions, pointer arrays, index arrays, and the value array. 
+ If this is a specifier for slices, it also holds the extra strides/offsets for tensor + dimensions. Example: ```mlir %0 = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#CSR> + %1 = sparse_tensor.storage_specifier.init with %src + : from !sparse_tensor.storage_specifier<#CSR> to + !sparse_tensor.storage_specifier<#CSR_SLICE> ``` }]; - let assemblyFormat = "attr-dict `:` qualified(type($result))"; + let builders = [ + OpBuilder<(ins "Type":$result), + [{ + build($_builder, $_state, result, Value()); + }]> + ]; + + let assemblyFormat = "attr-dict (`with` $source^)? `:` (`from` qualified(type($source))^ `to`)?" + " qualified(type($result))"; } def SparseTensor_GetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.get", [Pure]>, diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -577,6 +577,12 @@ } auto enc = md.getType().getEncoding(); + // TODO: + // if (mdKind == StorageSpecifierKind::DimOffset || + // mdKind == StorageSpecifierKind::DimStride) + // if (!enc.isSlice()) + // return op->emitError("requested slice data on non-slice tensor"); + ArrayRef dlts = enc.getDimLevelType(); unsigned rank = dlts.size(); @@ -665,6 +671,20 @@ return success(); } +LogicalResult ToSliceOffsetOp::verify() { + auto rank = getSlice().getType().cast().getRank(); + if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) + return emitError("requested dimension out of bound"); + return success(); +} + +LogicalResult ToSliceStrideOp::verify() { + auto rank = getSlice().getType().cast().getRank(); + if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) + return emitError("requested dimension out of bound"); + return success(); +} + LogicalResult GetStorageSpecifierOp::verify() { if (failed(verifySparsifierGetterSetter(getSpecifierKind(), 
getDim(), getSpecifier(), getOperation()))) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -20,18 +20,28 @@ // Helper methods. //===----------------------------------------------------------------------===// -static SmallVector getSpecifierFields(StorageSpecifierType tp) { +static SmallVector getSpecifierFields(StorageSpecifierType tp) { MLIRContext *ctx = tp.getContext(); auto enc = tp.getEncoding(); unsigned rank = enc.getDimLevelType().size(); - SmallVector result; + SmallVector result; auto indexType = tp.getSizesType(); auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rank); auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType, getNumDataFieldsFromEncoding(enc)); result.push_back(dimSizes); result.push_back(memSizes); + + if (enc.isSlice()) { + // Extra fields are required for the slice information + auto dimOffset = LLVM::LLVMArrayType::get(ctx, indexType, rank); + auto dimStride = LLVM::LLVMArrayType::get(ctx, indexType, rank); + + result.push_back(dimOffset); + result.push_back(dimStride); + } + return result; } @@ -46,6 +56,8 @@ constexpr uint64_t kDimSizePosInSpecifier = 0; constexpr uint64_t kMemSizePosInSpecifier = 1; +constexpr uint64_t kDimOffsetPosInSpecifier = 2; +constexpr uint64_t kDimStridePosInSpecifier = 3; class SpecifierStructBuilder : public StructBuilder { public: @@ -54,34 +66,63 @@ } // Undef value for dimension sizes, all zero value for memory sizes. 
- static Value getInitValue(OpBuilder &builder, Location loc, Type structType); + static Value getInitValue(OpBuilder &builder, Location loc, Type structType, + Value source); + + Value dimOffset(OpBuilder &builder, Location loc, unsigned dim) const; + void setDimOffset(OpBuilder &builder, Location loc, unsigned dim, Value size); - Value dimSize(OpBuilder &builder, Location loc, unsigned dim); + Value dimSize(OpBuilder &builder, Location loc, unsigned dim) const; void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size); - Value memSize(OpBuilder &builder, Location loc, unsigned pos); + Value dimStride(OpBuilder &builder, Location loc, unsigned dim) const; + void setDimStride(OpBuilder &builder, Location loc, unsigned dim, Value size); + + Value memSize(OpBuilder &builder, Location loc, unsigned pos) const; void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size); + + Value memSizeArray(OpBuilder &builder, Location loc) const; + void setMemSizeArray(OpBuilder &builder, Location loc, Value array); }; Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc, - Type structType) { + Type structType, Value source) { Value metaData = builder.create(loc, structType); SpecifierStructBuilder md(metaData); - auto memSizeArrayType = structType.cast() - .getBody()[kMemSizePosInSpecifier] - .cast(); + if (!source) { + auto memSizeArrayType = structType.cast() + .getBody()[kMemSizePosInSpecifier] + .cast(); + + Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); + // Fill memSizes array with zero. 
+ for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) + md.setMemSize(builder, loc, i, zero); + } else { + // Copy the non-slice information (the memory sizes array) from source. + SpecifierStructBuilder sourceMd(source); + md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc)); + } + return md; +} - Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); - // Fill memSizes array with zero. - for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) - md.setMemSize(builder, loc, i, zero); +/// Builds IR extracting the dim-th offset from the descriptor. +Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc, + unsigned dim) const { + return builder.create( + loc, value, ArrayRef({kDimOffsetPosInSpecifier, dim})); +} - return md; +/// Builds IR inserting the dim-th offset into the descriptor. +void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc, + unsigned dim, Value size) { + value = builder.create( + loc, value, size, ArrayRef({kDimOffsetPosInSpecifier, dim})); } /// Builds IR inserting the pos-th size into the descriptor. Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc, - unsigned dim) { + unsigned dim) const { return builder.create( loc, value, ArrayRef({kDimSizePosInSpecifier, dim})); } @@ -93,9 +134,23 @@ loc, value, size, ArrayRef({kDimSizePosInSpecifier, dim})); } +/// Builds IR extracting the dim-th stride from the descriptor. +Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc, + unsigned dim) const { + return builder.create( + loc, value, ArrayRef({kDimStridePosInSpecifier, dim})); +} + +/// Builds IR inserting the dim-th stride into the descriptor. +void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc, + unsigned dim, Value size) { + value = builder.create( + loc, value, size, ArrayRef({kDimStridePosInSpecifier, dim})); +} + /// Builds IR extracting the pos-th memory size into the descriptor. 
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc, - unsigned pos) { + unsigned pos) const { return builder.create( loc, value, ArrayRef({kMemSizePosInSpecifier, pos})); } @@ -107,6 +162,20 @@ loc, value, size, ArrayRef({kMemSizePosInSpecifier, pos})); } +/// Builds IR extracting the memory size array from the descriptor. +Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder, + Location loc) const { + return builder.create(loc, value, + kMemSizePosInSpecifier); +} + +/// Builds IR inserting the memory size array into the descriptor. +void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc, + Value array) { + value = builder.create(loc, value, array, + kMemSizePosInSpecifier); +} + } // namespace //===----------------------------------------------------------------------===// @@ -132,22 +201,40 @@ matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SpecifierStructBuilder spec(adaptor.getSpecifier()); - Value v; - if (op.getSpecifierKind() == StorageSpecifierKind::DimSize) { - v = Base::onDimSize(rewriter, op, spec, - op.getDim().value().getZExtValue()); - } else { + switch (op.getSpecifierKind()) { + case StorageSpecifierKind::DimOffset: { + Value v = + Base::onDimOffset(rewriter, op, spec, (*op.getDim()).getZExtValue()); + rewriter.replaceOp(op, v); + return success(); + } + case StorageSpecifierKind::DimSize: { + Value v = + Base::onDimSize(rewriter, op, spec, (*op.getDim()).getZExtValue()); + rewriter.replaceOp(op, v); + return success(); + } + case StorageSpecifierKind::DimStride: { + Value v = + Base::onDimStride(rewriter, op, spec, (*op.getDim()).getZExtValue()); + rewriter.replaceOp(op, v); + return success(); + } + case StorageSpecifierKind::IdxMemSize: + case StorageSpecifierKind::PtrMemSize: + case StorageSpecifierKind::ValMemSize: { auto enc = op.getSpecifier().getType().getEncoding(); StorageLayout layout(enc); Optional dim = std::nullopt; if 
(op.getDim()) - dim = op.getDim().value().getZExtValue(); + dim = (*op.getDim()).getZExtValue(); unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), dim); - v = Base::onMemSize(rewriter, op, spec, idx); + Value v = Base::onMemSize(rewriter, op, spec, idx); + rewriter.replaceOp(op, v); + return success(); } - - rewriter.replaceOp(op, v); - return success(); + } + llvm_unreachable("unrecognized specifier kind"); } }; @@ -155,12 +242,25 @@ : public SpecifierGetterSetterOpConverter { using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; + + static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op, + SpecifierStructBuilder &spec, unsigned d) { + spec.setDimOffset(builder, op.getLoc(), d, op.getValue()); + return spec; + } + static Value onDimSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, unsigned d) { spec.setDimSize(builder, op.getLoc(), d, op.getValue()); return spec; } + static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op, + SpecifierStructBuilder &spec, unsigned d) { + spec.setDimStride(builder, op.getLoc(), d, op.getValue()); + return spec; + } + static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, unsigned i) { spec.setMemSize(builder, op.getLoc(), i, op.getValue()); @@ -172,12 +272,24 @@ : public SpecifierGetterSetterOpConverter { using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; + + static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op, + const SpecifierStructBuilder &spec, unsigned d) { + return spec.dimOffset(builder, op.getLoc(), d); + } + static Value onDimSize(OpBuilder &builder, GetStorageSpecifierOp op, - SpecifierStructBuilder &spec, unsigned d) { + const SpecifierStructBuilder &spec, unsigned d) { return spec.dimSize(builder, op.getLoc(), d); } + + static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op, + const SpecifierStructBuilder &spec, unsigned d) { + 
return spec.dimStride(builder, op.getLoc(), d); + } + static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, - SpecifierStructBuilder &spec, unsigned i) { + const SpecifierStructBuilder &spec, unsigned i) { return spec.memSize(builder, op.getLoc(), i); } }; @@ -190,8 +302,9 @@ matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type llvmType = getTypeConverter()->convertType(op.getResult().getType()); - rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue( - rewriter, op.getLoc(), llvmType)); + rewriter.replaceOp( + op, SpecifierStructBuilder::getInitValue( + rewriter, op.getLoc(), llvmType, adaptor.getSource())); return success(); } }; diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -116,6 +116,32 @@ // ----- +#CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (1, 4, 1), (1, 4, 2) ] +}> + +func.func @sparse_slice_offset(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index { + // expected-error@+1 {{requested dimension out of bound}} + %0 = sparse_tensor.slice.offset %arg0 at 2 : tensor<2x8xf64, #CSR_SLICE> + return %0 : index +} + +// ----- + +#CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (1, 4, 1), (1, 4, 2) ] +}> + +func.func @sparse_slice_stride(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index { + // expected-error@+1 {{requested dimension out of bound}} + %0 = sparse_tensor.slice.stride %arg0 at 2 : tensor<2x8xf64, #CSR_SLICE> + return %0 : index +} + +// ----- + #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { @@ -125,6 +151,17 @@ return %0 : i64 } +//// ----- +// +//#SparseVector = #sparse_tensor.encoding<{dimLevelType = 
["compressed"]}> +// +//func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { +// // _e_xpected-error@+1 {{requested slice data on non-slice tensor}} +// %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0 +// : !sparse_tensor.storage_specifier<#SparseVector> to i64 +// return %0 : i64 +//} + // ----- #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -117,6 +117,38 @@ // ----- +#CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (1, 4, 1), (1, 4, 2) ] +}> + +// CHECK-LABEL: func @sparse_slice_offset( +// CHECK-SAME: %[[A:.*]]: tensor<2x8xf64, #{{.*}}>) +// CHECK: %[[T:.*]] = sparse_tensor.slice.offset %[[A]] at 1 : tensor<2x8xf64, #{{.*}}> +// CHECK: return %[[T]] : index +func.func @sparse_slice_offset(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index { + %0 = sparse_tensor.slice.offset %arg0 at 1 : tensor<2x8xf64, #CSR_SLICE> + return %0 : index +} + +// ----- + +#CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (1, 4, 1), (1, 4, 2) ] +}> + +// CHECK-LABEL: func @sparse_slice_stride( +// CHECK-SAME: %[[A:.*]]: tensor<2x8xf64, #{{.*}}>) +// CHECK: %[[T:.*]] = sparse_tensor.slice.stride %[[A]] at 1 : tensor<2x8xf64, #{{.*}}> +// CHECK: return %[[T]] : index +func.func @sparse_slice_stride(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index { + %0 = sparse_tensor.slice.stride %arg0 at 1 : tensor<2x8xf64, #CSR_SLICE> + return %0 : index +} + +// ----- + #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> // CHECK-LABEL: func @sparse_metadata_init( @@ -129,6 +161,25 @@ // ----- +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> +#SparseVector_Slice = 
#sparse_tensor.encoding<{ + dimLevelType = ["compressed"], + slice = [ (?, ?, ?) ] +}> + +// CHECK-LABEL: func @sparse_metadata_init( +// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}> +// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.init with %[[A]] : +// CHECK: return %[[T]] : !sparse_tensor.storage_specifier<#{{.*}}> +func.func @sparse_metadata_init(%src : !sparse_tensor.storage_specifier<#SparseVector>) + -> !sparse_tensor.storage_specifier<#SparseVector_Slice> { + %0 = sparse_tensor.storage_specifier.init with %src : from !sparse_tensor.storage_specifier<#SparseVector> + to !sparse_tensor.storage_specifier<#SparseVector_Slice> + return %0 : !sparse_tensor.storage_specifier<#SparseVector_Slice> +} + +// ----- + #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> // CHECK-LABEL: func @sparse_get_md( @@ -141,6 +192,41 @@ return %0 : i64 } +// ----- + +#SparseVector_Slice = #sparse_tensor.encoding<{ + dimLevelType = ["compressed"], + slice = [ (?, ?, ?) ] +}> + +// CHECK-LABEL: func @sparse_get_md( +// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}> +// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_offset at 0 +// CHECK: return %[[T]] : i64 +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector_Slice>) -> i64 { + %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0 + : !sparse_tensor.storage_specifier<#SparseVector_Slice> to i64 + return %0 : i64 +} + +// ----- + +#SparseVector = #sparse_tensor.encoding<{ + dimLevelType = ["compressed"], + slice = [ (?, ?, ?) 
] +}> + +// CHECK-LABEL: func @sparse_get_md( +// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}> +// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_stride at 0 +// CHECK: return %[[T]] : i64 +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { + %0 = sparse_tensor.storage_specifier.get %arg0 dim_stride at 0 + : !sparse_tensor.storage_specifier<#SparseVector> to i64 + return %0 : i64 +} + + // ----- #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>