diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -277,6 +277,12 @@
   return forOp;
 }
 
+/// Translates a field index to the corresponding memSizes index.
+static unsigned getMemSizesIndex(unsigned field) {
+  assert(2 <= field);
+  return field - 2;
+}
+
 /// Creates a pushback op for given field and updates the fields array
 /// accordingly.
 static void createPushback(OpBuilder &builder, Location loc,
@@ -288,7 +294,8 @@
     value = builder.create<arith::IndexCastOp>(loc, etp, value);
   fields[field] =
       builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
-                                 fields[field], value, APInt(64, field - 2));
+                                 fields[field], value,
+                                 APInt(64, getMemSizesIndex(field)));
 }
 
 /// Generates insertion code.
@@ -739,6 +746,24 @@
   }
 };
 
+/// Sparse codegen rule for number of entries operator.
+class SparseNumberOfEntriesConverter
+    : public OpConversionPattern<NumberOfEntriesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Query memSizes for the actual size of the stored values.
+    auto tuple = getTuple(adaptor.getTensor());
+    auto fields = tuple.getInputs();
+    unsigned lastField = fields.size() - 1;
+    Value field = constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[1], field);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -775,5 +800,6 @@
                SparseExpandConverter, SparseCompressConverter,
                SparseInsertConverter, SparseToPointersConverter,
                SparseToIndicesConverter, SparseToValuesConverter,
-               SparseConvertConverter>(typeConverter, patterns.getContext());
+               SparseConvertConverter, SparseNumberOfEntriesConverter>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1029,6 +1029,28 @@
   }
 };
 
+/// Sparse conversion rule for number of entries operator.
+class SparseNumberOfEntriesConverter
+    : public OpConversionPattern<NumberOfEntriesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Query the values array size for the actual size of the stored values.
+    Location loc = op->getLoc();
+    Type eltType = op.getTensor().getType().cast<ShapedType>().getElementType();
+    Type resTp = MemRefType::get({ShapedType::kDynamicSize}, eltType);
+    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
+    Value values = createFuncCall(rewriter, loc, name, resTp,
+                                  adaptor.getOperands(), EmitCInterface::On)
+                       .getResult(0);
+    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
+                                               constantIndex(rewriter, loc, 0));
+    return success();
+  }
+};
+
 /// Sparse conversion rule for tensor rematerialization.
 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
 public:
@@ -1367,9 +1389,10 @@
       SparseTensorConcatConverter, SparseTensorAllocConverter,
       SparseTensorDeallocConverter, SparseTensorToPointersConverter,
       SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
-      SparseTensorLoadConverter, SparseTensorInsertConverter,
-      SparseTensorExpandConverter, SparseTensorCompressConverter,
-      SparseTensorOutConverter>(typeConverter, patterns.getContext());
+      SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
+      SparseTensorInsertConverter, SparseTensorExpandConverter,
+      SparseTensorCompressConverter, SparseTensorOutConverter>(
+      typeConverter, patterns.getContext());
   patterns.add<SparseTensorConvertConverter>(typeConverter,
                                              patterns.getContext(), options);
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -239,6 +239,20 @@
   return %0 : memref<?xf64>
 }
 
+// CHECK-LABEL: func @sparse_noe(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+//       CHECK: %[[C2:.*]] = arith.constant 2 : index
+//       CHECK: %[[NOE:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+//       CHECK: return %[[NOE]] : index
+func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
+  %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
+  return %0 : index
+}
+
 // CHECK-LABEL: func @sparse_dealloc_csr(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -268,6 +268,17 @@
   return %0 : memref<?xf64>
 }
 
+// CHECK-LABEL: func @sparse_noe(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//   CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[T:.*]] = call @sparseValuesF64(%[[A]]) : (!llvm.ptr<i8>) -> memref<?xf64>
+//       CHECK: %[[NOE:.*]] = memref.dim %[[T]], %[[C]] : memref<?xf64>
+//       CHECK: return %[[NOE]] : index
+func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
+  %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
+  return %0 : index
+}
+
 // CHECK-LABEL: func @sparse_reconstruct(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>
 //       CHECK: return %[[A]] : !llvm.ptr<i8>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
@@ -46,6 +46,16 @@
     %1 = tensor.extract %0[] : tensor<f32>
     vector.print %1 : f32
 
+    // Print number of entries in the sparse vectors.
+    //
+    // CHECK: 5
+    // CHECK: 3
+    //
+    %noe1 = sparse_tensor.number_of_entries %s1 : tensor<1024xf32, #SparseVector>
+    %noe2 = sparse_tensor.number_of_entries %s2 : tensor<1024xf32, #SparseVector>
+    vector.print %noe1 : index
+    vector.print %noe2 : index
+
     // Release the resources.
     bufferization.dealloc_tensor %s1 : tensor<1024xf32, #SparseVector>
     bufferization.dealloc_tensor %s2 : tensor<1024xf32, #SparseVector>
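
For context, a minimal standalone use of the new operator. This is a sketch
only: the @noe_example function name is illustrative, and the encoding is
assumed to match the #SparseVector encoding used in the tests above.

  #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

  // Returns the number of stored entries (which may include explicit zeros)
  // in the sparse vector.
  func.func @noe_example(%arg0: tensor<1024xf32, #SparseVector>) -> index {
    %noe = sparse_tensor.number_of_entries %arg0 : tensor<1024xf32, #SparseVector>
    return %noe : index
  }

With --sparse-tensor-codegen this lowers to a memref.load of the last memSizes
entry, as checked in codegen.mlir; with --sparse-tensor-conversion it lowers
to a memref.dim on the values array returned by the sparseValuesF32 runtime
call, mirroring the sparseValuesF64 call checked in conversion.mlir.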