diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -51,6 +51,38 @@
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
 }
 
+def SparseTensor_ConvertOp : SparseTensor_Op<"convert", [NoSideEffect, SameOperandsAndResultElementType]>,
+    Arguments<(ins AnyTensor:$source)>,
+    Results<(outs AnyTensor:$dest)> {
+  string summary = "Converts between different tensor types";
+  string description = [{
+    Converts one sparse or dense tensor type to another tensor type. The rank
+    and dimensions of the source and destination types must match exactly;
+    only the sparse encoding of these types may differ. The name `convert`
+    was preferred over `cast`, since the operation may incur a non-trivial cost.
+
+    When converting between two different sparse tensor types, only explicitly
+    stored values are moved from one underlying sparse storage format to
+    the other. When converting from an unannotated dense tensor type to a
+    sparse tensor type, an explicit test for nonzero values is used. When
+    converting to an unannotated dense tensor type, implicit zeroes in the
+    sparse storage format are made explicit. Note that the conversions can have
+    non-trivial costs associated with them, since they may involve elaborate
+    data structure transformations. Also, conversions from sparse tensor types
+    into dense tensor types may be infeasible in terms of storage requirements.
+
+    Examples:
+
+    ```mlir
+    %0 = sparse_tensor.convert %1 : tensor<32x32xf32> to tensor<32x32xf32, #CSR>
+
+    %2 = sparse_tensor.convert %3 : tensor<8x8xi32, #CSC> to tensor<8x8xi32, #CSR>
+    ```
+
+  }];
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+}
+
 def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
     Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -208,11 +208,27 @@
 }
 
 static LogicalResult verify(NewOp op) {
-  if (!getSparseTensorEncoding(op.getResult().getType()))
+  if (!getSparseTensorEncoding(op.result().getType()))
     return op.emitError("expected a sparse tensor result");
   return success();
 }
 
+static LogicalResult verify(ConvertOp op) {
+  if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
+    if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
+      assert(tp1.getRank() == tp2.getRank());
+      auto shape1 = tp1.getShape();
+      auto shape2 = tp2.getShape();
+      for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
+        if (shape1[d] != shape2[d])
+          return op.emitError()
+                 << "unexpected conversion mismatch in dimension " << d;
+      return success();
+    }
+  }
+  return op.emitError("unexpected type in convert");
+}
+
 static LogicalResult verify(ToPointersOp op) {
   if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
     if (failed(isInBounds(op.dim(), op.tensor())))
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -27,16 +27,23 @@
 
 namespace {
 
-/// Internal encoding of primary storage. Keep this enum consistent
-/// with the equivalent enum in the sparse runtime support library.
-enum PrimaryTypeEnum : uint64_t {
-  kF64 = 1,
-  kF32 = 2,
-  kI64 = 3,
-  kI32 = 4,
-  kI16 = 5,
-  kI8 = 6
-};
+/// Returns internal type encoding for primary storage. Keep these
+/// values consistent with the sparse runtime support library.
+static unsigned getPrimaryTypeEncoding(Type tp) {
+  if (tp.isF64())
+    return 1;
+  if (tp.isF32())
+    return 2;
+  if (tp.isInteger(64))
+    return 3;
+  if (tp.isInteger(32))
+    return 4;
+  if (tp.isInteger(16))
+    return 5;
+  if (tp.isInteger(8))
+    return 6;
+  return 0;
+}
 
 /// Returns internal type encoding for overhead storage. Keep these
 /// values consistent with the sparse runtime support library.
@@ -170,20 +177,8 @@
     // Secondary and primary types encoding.
     unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
     unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
-    unsigned primary;
-    if (eltType.isF64())
-      primary = kF64;
-    else if (eltType.isF32())
-      primary = kF32;
-    else if (eltType.isInteger(64))
-      primary = kI64;
-    else if (eltType.isInteger(32))
-      primary = kI32;
-    else if (eltType.isInteger(16))
-      primary = kI16;
-    else if (eltType.isInteger(8))
-      primary = kI8;
-    else
+    unsigned primary = getPrimaryTypeEncoding(eltType);
+    if (!primary)
       return failure();
     params.push_back(
         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
@@ -200,6 +195,17 @@
   }
 };
 
+/// Sparse conversion rule for the convert operator.
+class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: implement conversion lowering
+    return failure();
+  }
+};
+
 /// Sparse conversion rule for pointer accesses.
 class SparseTensorToPointersConverter
     : public OpConversionPattern<ToPointersOp> {
@@ -324,8 +330,8 @@
 
 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                   RewritePatternSet &patterns) {
-  patterns.add<SparseTensorNewConverter, SparseTensorToPointersConverter,
-               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
-               SparseTensorToTensorConverter>(typeConverter,
-                                              patterns.getContext());
+  patterns.add<SparseTensorNewConverter, SparseTensorConvertConverter,
+               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
+               SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -111,3 +111,21 @@
   %0 = sparse_tensor.tensor %arg0 : memref<?xf64> to tensor<16x32xf64>
   return %0 : tensor<16x32xf64>
 }
+
+// -----
+
+func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
+  // expected-error@+1 {{unexpected type in convert}}
+  %0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+
+func @sparse_convert_mismatch(%arg0: tensor<10x10xf32>) -> tensor<10x?xf32, #CSR> {
+  // expected-error@+1 {{unexpected conversion mismatch in dimension 1}}
+  %0 = sparse_tensor.convert %arg0 : tensor<10x10xf32> to tensor<10x?xf32, #CSR>
+  return %0 : tensor<10x?xf32, #CSR>
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -15,6 +15,32 @@
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
+// CHECK-LABEL: func @sparse_convert_1d_to_sparse(
+//  CHECK-SAME: %[[A:.*]]: tensor<64xf32>)
+//       CHECK: %[[T:.*]] = sparse_tensor.convert %[[A]] : tensor<64xf32> to tensor<64xf32, #{{.*}}>
+//       CHECK: return %[[T]] : tensor<64xf32, #{{.*}}>
+func @sparse_convert_1d_to_sparse(%arg0: tensor<64xf32>) -> tensor<64xf32, #SparseVector> {
+  %0 = sparse_tensor.convert %arg0 : tensor<64xf32> to tensor<64xf32, #SparseVector>
+  return %0 : tensor<64xf32, #SparseVector>
+}
+
+// -----
+
+#SparseTensor = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ] }>
+
+// CHECK-LABEL: func @sparse_convert_3d_from_sparse(
+//  CHECK-SAME: %[[A:.*]]: tensor<8x8x8xf64, #{{.*}}>)
+//       CHECK: %[[T:.*]] = sparse_tensor.convert %[[A]] : tensor<8x8x8xf64, #{{.*}}> to tensor<8x8x8xf64>
+//       CHECK: return %[[T]] : tensor<8x8x8xf64>
+func @sparse_convert_3d_from_sparse(%arg0: tensor<8x8x8xf64, #SparseTensor>) -> tensor<8x8x8xf64> {
+  %0 = sparse_tensor.convert %arg0 : tensor<8x8x8xf64, #SparseTensor> to tensor<8x8x8xf64>
+  return %0 : tensor<8x8x8xf64>
+}
+
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
 // CHECK-LABEL: func @sparse_pointers(
 //  CHECK-SAME: %[[A:.*]]: tensor<128xf64, #{{.*}}>)
 //       CHECK: %[[C:.*]] = constant 0 : index
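
For context, not part of the patch: a minimal sketch of how the new operation reads end-to-end, assuming the same #CSR encoding used in the tests above. The @round_trip function and its shapes are hypothetical, and since the conversion lowering is still a TODO in this change, the example only exercises parsing and verification.

```mlir
// Hypothetical example (not in this patch): sparsify a dense tensor to
// CSR, then densify it again. Both converts pass the verifier because
// ranks and dimension sizes match exactly; only the encodings differ.
#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>

func @round_trip(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> {
  // Dense to sparse: performs an explicit test for nonzero values.
  %0 = sparse_tensor.convert %arg0 : tensor<32x32xf32> to tensor<32x32xf32, #CSR>
  // Sparse to dense: makes the implicit zeroes explicit again.
  %1 = sparse_tensor.convert %0 : tensor<32x32xf32, #CSR> to tensor<32x32xf32>
  return %1 : tensor<32x32xf32>
}
```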