diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -33,16 +33,6 @@
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Reorders stored dimension to original dimension.
-static unsigned toOrig(const SparseTensorEncodingAttr &enc, unsigned i) {
-  auto order = enc.getDimOrdering();
-  if (order) {
-    assert(order.isPermutation());
-    return order.getDimPosition(i);
-  }
-  return i;
-}
-
 /// Reorders original dimension to stored dimension.
 static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
   auto order = enc.getDimOrdering();
@@ -67,7 +57,6 @@
   Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
   Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
   Type eltType = rType.getElementType();
-  ArrayRef<int64_t> shape = rType.getShape();
   //
   // Sparse tensor storage for rank-dimensional tensor is organized as a
   // single compound type with the following fields:
@@ -85,27 +74,18 @@
   //   memref<? x eltType> values      ; values
   // };
   //
-  int64_t linear = 1;
-  bool allDense = true;
   unsigned rank = rType.getShape().size();
   SmallVector<Type, 8> fields;
   // The dimSizes array.
   fields.push_back(MemRefType::get({rank}, indexType));
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
-    // Get the original dimension (ro) for the current stored dimension (r).
-    unsigned ro = toOrig(enc, r);
     // Dimension level types apply in order to the reordered dimension.
     // As a result, the compound type can be constructed directly in the given
     // order. Clients of this type know what field is what from the sparse
     // tensor type.
     switch (enc.getDimLevelType()[r]) {
     case SparseTensorEncodingAttr::DimLevelType::Dense:
-      // Linearize the size of consecutive dense dimensions.
-      if (ShapedType::isDynamic(shape[ro]) || ShapedType::isDynamic(linear))
-        linear = ShapedType::kDynamicSize;
-      else
-        linear *= shape[ro];
       break;
     case SparseTensorEncodingAttr::DimLevelType::Compressed:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
@@ -113,23 +93,17 @@
     case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-      allDense = false;
-      linear = 1;
       break;
     case SparseTensorEncodingAttr::DimLevelType::Singleton:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-      allDense = false;
-      linear = 1;
       break;
     }
   }
   // The values array.
-  int64_t nnz =
-      (rType.hasStaticShape() && allDense) ? linear : ShapedType::kDynamicSize;
-  fields.push_back(MemRefType::get({nnz}, eltType));
+  fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
   // Sparse tensor storage (temporarily) lives in a tuple. This allows a
   // simple 1:1 type conversion during codegen. A subsequent pass uses
   // a 1:N type conversion to expand the tuple into its fields.
@@ -241,6 +215,23 @@
   }
 };
 
+/// Sparse codegen rule for trivial tensor casts.
+class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only rewrite identically annotated source/dest.
+    auto encDst = getSparseTensorEncoding(op.getType());
+    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
+    if (!encDst || encDst != encSrc)
+      return failure();
+    rewriter.replaceOp(op, adaptor.getOperands());
+    return success();
+  }
+};
+
 /// Sparse conversion rule for pointer accesses.
 class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
 public:
@@ -314,7 +305,7 @@
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
-  patterns.add<SparseReturnConverter, SparseDimOpConverter,
+  patterns.add<SparseReturnConverter, SparseCastConverter, SparseDimOpConverter,
                SparseToPointersConverter, SparseToIndicesConverter,
                SparseToValuesConverter>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -36,12 +36,28 @@
 }>
 
 // CHECK-LABEL: func @sparse_nop(
-// CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>) -> tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
+// CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
 // CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
 func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
   return %arg0 : tensor<?xf64, #SparseVector>
 }
 
+// CHECK-LABEL: func @sparse_nop_cast(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>)
+// CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>
+func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  return %0 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_nop_cast_3d(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf32>>)
+// CHECK: return %[[A]] : tuple<memref<3xindex>, memref<?xf32>>
+func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor<?x?x?xf32, #Dense3D> {
+  %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor<?x?x?xf32, #Dense3D>
+  return %0 : tensor<?x?x?xf32, #Dense3D>
+}
+
 // CHECK-LABEL: func @sparse_dense_2d(
 // CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xf64>>)
 func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
@@ -71,7 +87,7 @@
 // fold using the original static dimension sizes.
 //
 // CHECK-LABEL: func @sparse_dense_3d(
-// CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<6000xf64>>) -> index {
+// CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>)
 // CHECK: %[[C:.*]] = arith.constant 20 : index
 // CHECK: return %[[C]] : index
 func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -86,7 +102,7 @@
 // since the latter honors the dimOrdering.
 //
 // CHECK-LABEL: func @sparse_dense_3d_dyn(
-// CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>) -> index {
+// CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>)
 // CHECK: %[[C:.*]] = arith.constant 2 : index
 // CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<3xindex>, memref<?xf64>> to memref<3xindex>
 // CHECK: %[[L:.*]] = memref.load %[[F]][%[[C]]] : memref<3xindex>
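Note (reviewer aid, not part of the patch): a minimal sketch of what the new
SparseCastConverter folds, using a hypothetical #SV encoding with default
pointer/index bitwidths rather than the encodings defined in codegen.mlir.
Because source and destination carry the same annotation, both types convert
to the same storage tuple, so the cast lowers to a plain forwarding of the
converted operand; the values field is always dynamically sized now that the
dense-dimension linearization has been removed.

  #SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

  // Before codegen:
  %0 = tensor.cast %arg0 : tensor<64xf32, #SV> to tensor<?xf32, #SV>

  // After codegen, both sides map to the same storage tuple
  //   tuple<memref<1xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>>
  // (dimSizes, pointers, indices, values), so %0 is simply replaced by the
  // converted %arg0.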