diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -277,6 +277,12 @@
   return forOp;
 }
 
+/// Translates field index to memSizes index.
+static unsigned getMemSizesIndex(unsigned field) {
+  assert(2 <= field);
+  return field - 2;
+}
+
 /// Creates a pushback op for given field and updates the fields array
 /// accordingly.
 static void createPushback(OpBuilder &builder, Location loc,
@@ -286,9 +292,9 @@
   Type etp = fields[field].getType().cast<ShapedType>().getElementType();
   if (value.getType() != etp)
     value = builder.create<arith::IndexCastOp>(loc, etp, value);
-  fields[field] =
-      builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
-                                 fields[field], value, APInt(64, field - 2));
+  fields[field] = builder.create<PushBackOp>(
+      loc, fields[field].getType(), fields[1], fields[field], value,
+      APInt(64, getMemSizesIndex(field)));
 }
 
 /// Generates insertion code.
@@ -739,6 +745,25 @@
   }
 };
 
+/// Sparse codegen rule for number of entries operator.
+class SparseNumberOfEntriesConverter
+    : public OpConversionPattern<NumberOfEntriesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Query memSizes for the actually stored values size.
+    auto tuple = getTuple(adaptor.getTensor());
+    auto fields = tuple.getInputs();
+    unsigned lastField = fields.size() - 1;
+    Value field =
+        constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[1], field);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -775,5 +800,6 @@
                SparseExpandConverter, SparseCompressConverter,
                SparseInsertConverter, SparseToPointersConverter,
                SparseToIndicesConverter, SparseToValuesConverter,
-               SparseConvertConverter>(typeConverter, patterns.getContext());
+               SparseConvertConverter, SparseNumberOfEntriesConverter>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -205,6 +205,15 @@
   params.push_back(ptr);
 }
 
+/// Generates a call to obtain the values array.
+static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
+                           ValueRange ptr) {
+  SmallString<15> name{"sparseValues",
+                       primaryTypeFunctionSuffix(tp.getElementType())};
+  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
+      .getResult(0);
+}
+
 /// Generates a call to release/delete a `SparseTensorCOO`.
 static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
                           Value coo) {
@@ -903,11 +912,28 @@
   LogicalResult
   matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resType = op.getType();
-    Type eltType = resType.cast<ShapedType>().getElementType();
-    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
-    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
-                          EmitCInterface::On);
+    auto resType = op.getType().cast<ShapedType>();
+    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
+                                         adaptor.getOperands()));
+    return success();
+  }
+};
+
+/// Sparse conversion rule for number of entries operator.
+class SparseNumberOfEntriesConverter
+    : public OpConversionPattern<NumberOfEntriesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Query values array size for the actually stored values size.
+    Type eltType = op.getTensor().getType().cast<ShapedType>().getElementType();
+    auto resTp = MemRefType::get({ShapedType::kDynamicSize}, eltType);
+    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
+                                               constantIndex(rewriter, loc, 0));
     return success();
   }
 };
@@ -1250,9 +1276,10 @@
                SparseTensorConcatConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter, SparseTensorToPointersConverter,
                SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
-               SparseTensorLoadConverter, SparseTensorInsertConverter,
-               SparseTensorExpandConverter, SparseTensorCompressConverter,
-               SparseTensorOutConverter>(typeConverter, patterns.getContext());
+               SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
+               SparseTensorInsertConverter, SparseTensorExpandConverter,
+               SparseTensorCompressConverter, SparseTensorOutConverter>(
+      typeConverter, patterns.getContext());
   patterns.add<SparseTensorConvertConverter>(typeConverter,
                                              patterns.getContext(), options);
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -239,6 +239,20 @@
   return %0 : memref<?xf64>
 }
 
+// CHECK-LABEL: func @sparse_noe(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+//       CHECK: %[[C2:.*]] = arith.constant 2 : index
+//       CHECK: %[[NOE:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+//       CHECK: return %[[NOE]] : index
+func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
+  %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
+  return %0 : index
+}
+
 // CHECK-LABEL: func @sparse_dealloc_csr(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -268,6 +268,17 @@
   return %0 : memref<?xf64>
 }
 
+// CHECK-LABEL: func @sparse_noe(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//   CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[T:.*]] = call @sparseValuesF64(%[[A]]) : (!llvm.ptr<i8>) -> memref<?xf64>
+//       CHECK: %[[NOE:.*]] = memref.dim %[[T]], %[[C]] : memref<?xf64>
+//       CHECK: return %[[NOE]] : index
+func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
+  %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
+  return %0 : index
+}
+
 // CHECK-LABEL: func @sparse_reconstruct(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>
 //       CHECK: return %[[A]] : !llvm.ptr<i8>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
@@ -46,6 +46,16 @@
     %1 = tensor.extract %0[] : tensor<f32>
     vector.print %1 : f32
 
+    // Print number of entries in the sparse vectors.
+    //
+    // CHECK: 5
+    // CHECK: 3
+    //
+    %noe1 = sparse_tensor.number_of_entries %s1 : tensor<1024xf32, #SparseVector>
+    %noe2 = sparse_tensor.number_of_entries %s2 : tensor<1024xf32, #SparseVector>
+    vector.print %noe1 : index
+    vector.print %noe2 : index
+
     // Release the resources.
     bufferization.dealloc_tensor %s1 : tensor<1024xf32, #SparseVector>
     bufferization.dealloc_tensor %s2 : tensor<1024xf32, #SparseVector>
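
Usage note (not part of the patch): the snippet below is a minimal sketch of the new op as exercised by the tests above; the function name @num_entries and the exact #SparseVector encoding are illustrative assumptions. The direct codegen path lowers `number_of_entries` to a `memref.load` of the values slot in the memSizes array (the index computed by `getMemSizesIndex`), while the runtime conversion path materializes the values array through a `sparseValues*` call and queries its size with `memref.dim`. Either way, the op returns the number of actually stored entries, not the dimension size.

    // Hypothetical standalone example, assuming this encoding.
    #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

    // Returns the number of stored entries (not the dimension size 1024).
    func.func @num_entries(%v: tensor<1024xf32, #SparseVector>) -> index {
      %noe = sparse_tensor.number_of_entries %v : tensor<1024xf32, #SparseVector>
      return %noe : index
    }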