diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -937,6 +937,26 @@
   }
 };
 
+/// Sparse codegen rule for accessing the linear indices buffer.
+class SparseToIndicesBufferConverter
+    : public OpConversionPattern {
+public:
+  using OpAdaptor = typename ToIndicesBufferOp::Adaptor;
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToIndicesBufferOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Replace the requested indices buffer access with corresponding field.
+    // The cast_op is inserted by type converter to intermix 1:N type
+    // conversion.
+    SmallVector fields;
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    rewriter.replaceOp(op, desc.getAOSMemRef());
+
+    return success();
+  }
+};
+
 /// Sparse codegen rule for value accesses.
 class SparseToValuesConverter : public OpConversionPattern {
 public:
@@ -1005,9 +1025,9 @@
                SparseTensorLoadConverter, SparseExpandConverter,
                SparseCompressConverter, SparseInsertConverter,
                SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToValuesConverter, SparseConvertConverter,
-               SparseNumberOfEntriesConverter>(typeConverter,
-                                               patterns.getContext());
+               SparseToIndicesBufferConverter, SparseToValuesConverter,
+               SparseConvertConverter, SparseNumberOfEntriesConverter>(
+      typeConverter, patterns.getContext());
   patterns.add(typeConverter, patterns.getContext(),
                enableBufferInitialization);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -152,7 +152,8 @@
 // TODO: The dim level property of the COO type relies on input tensors, the
 // shape relies on the output tensor
 // Helpers to setup a COO type.
-static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
+static RankedTensorType
+getUnorderedCOOFromTypeWithOrdering(RankedTensorType src, AffineMap ordering) {
   auto *ctx = src.getContext();
   auto rank = src.getRank();
   SmallVector dims;
@@ -176,12 +177,16 @@
   // default value.
   unsigned pointerBitWidth = encSrc ? encSrc.getPointerBitWidth() : 0;
   unsigned indexBitWidth = encSrc ? encSrc.getIndexBitWidth() : 0;
-  auto enc = SparseTensorEncodingAttr::get(
-      ctx, dims, AffineMap::getMultiDimIdentityMap(rank, ctx), AffineMap(),
-      pointerBitWidth, indexBitWidth);
+  auto enc = SparseTensorEncodingAttr::get(ctx, dims, ordering, AffineMap(),
+                                           pointerBitWidth, indexBitWidth);
   return RankedTensorType::get(src.getShape(), src.getElementType(), enc);
 }
 
+static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
+  return getUnorderedCOOFromTypeWithOrdering(
+      src, AffineMap::getMultiDimIdentityMap(src.getRank(), src.getContext()));
+}
+
 /// Collects the dynamic dimension sizes for `tp` with the assumption that
 /// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
 /// sizes to dynSizes.
@@ -771,6 +776,7 @@
     RankedTensorType srcTp = src.getType().cast();
     RankedTensorType dstTp = op.getType().cast();
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
+    int64_t rank = dstTp.getRank();
 
     SmallVector srcSizes;
     sizesForTensor(rewriter, srcSizes, loc, srcTp, src);
@@ -788,16 +794,21 @@
       // the overhead types.
       SmallVector dynSrcSizes;
       getDynamicSizes(srcTp, srcSizes, dynSrcSizes);
-      srcTp = getUnorderedCOOFromType(srcTp);
+      srcTp =
+          getUnorderedCOOFromTypeWithOrdering(srcTp, encDst.getDimOrdering());
       tmpCoo =
           rewriter.create(loc, srcTp, dynSrcSizes).getResult();
       auto foreachOp = rewriter.create(
           loc, src, tmpCoo,
           [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
               ValueRange reduc) {
-            // The resulting COO tensor has identity ordering.
-            auto t = builder.create(loc, v, reduc.front(),
-                                    args.slice(0, srcTp.getRank()));
+            SmallVector dstIndices(srcTp.getRank(), Value());
+            for (int64_t i = 0; i < rank; i++) {
+              uint64_t dim = toStoredDim(encDst, i);
+              dstIndices[dim] = args[i];
+            }
+            auto t =
+                builder.create(loc, v, reduc.front(), dstIndices);
             builder.create(loc, t);
           });
       src = rewriter.create(loc, foreachOp.getResult(0), true);
@@ -806,19 +817,6 @@
     // Only need to sort if the srcTp is not already sorted (we faithfully take
     // the guarantee from the sparse tensor encoding).
     if (!isAllDimOrdered(srcTp)) {
-      // Sort the COO tensor so that its elements are ordered via increasing
-      // indices for the storage ordering of the dst tensor.
-      SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
-      uint64_t rank = dstTp.getRank();
-      uint64_t cooStart = getCOOStart(encSrc);
-      // Gather the indices-arrays in the dst tensor storage order.
-      SmallVector xs(rank, Value());
-      for (uint64_t i = 0; i < rank; i++) {
-        uint64_t orgDim = toOrigDim(encSrc, i);
-        xs[toStoredDim(encDst, orgDim)] =
-            genToIndices(rewriter, loc, src, i, cooStart);
-      }
-
       // Retrieve NNZ.
       Value nnz = rewriter.create(loc, src);
       nnz = rewriter.create(loc, rewriter.getIndexType(),
@@ -826,9 +824,28 @@
 
       // Retrieve the values-array.
       Value y = genToValues(rewriter, loc, src);
-
-      // Sort the COO tensor.
-      rewriter.create(loc, nnz, xs, ValueRange{y});
+      SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
+      // Sort the COO tensor so that its elements are ordered via increasing
+      // indices for the storage ordering of the dst tensor. Use SortCoo only
+      // when rank > 1 and the COO tensor shares the dst tensor's dim ordering.
+      if (rank > 1 && hasSameDimOrdering(srcTp, dstTp)) {
+        MemRefType indTp =
+            get1DMemRefType(getIndexOverheadType(rewriter, encSrc),
+                            /*withLayout=*/false);
+        Value xs = rewriter.create(loc, indTp, src);
+        rewriter.create(loc, nnz, xs, ValueRange{y},
+                        rewriter.getIndexAttr(rank),
+                        rewriter.getIndexAttr(0));
+      } else {
+        // Gather the indices-arrays in the dst tensor storage order.
+        SmallVector xs(rank, Value());
+        for (uint64_t i = 0; i < rank; i++) {
+          uint64_t orgDim = toOrigDim(encSrc, i);
+          xs[toStoredDim(encDst, orgDim)] =
+              genToIndices(rewriter, loc, src, i, /*cooStart=*/0);
+        }
+        rewriter.create(loc, nnz, xs, ValueRange{y});
+      }
     }
 
     // For each element in the COO tensor, insert the element to the dst tensor.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -390,6 +390,13 @@
     return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
                                          idxDim);
   }
+
+  Value getAOSMemRef() const {
+    auto enc = getSparseTensorEncoding(rType);
+    unsigned cooStart = getCOOStart(enc);
+    assert(cooStart < enc.getDimLevelType().size());
+    return getIdxMemRef(cooStart);
+  }
 };
 
 class SparseTensorDescriptor : public SparseTensorDescriptorImpl {
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -270,6 +270,19 @@
   return %0 : memref>
 }
 
+// CHECK-LABEL: func.func @sparse_indices_buffer_coo(
+// CHECK-SAME: %[[A0:.*0]]: memref,
+// CHECK-SAME: %[[A1:.*1]]: memref,
+// CHECK-SAME: %[[A2:.*2]]: memref,
+// CHECK-SAME: %[[A3:.*3]]: memref,
+// CHECK-SAME: %[[A4:.*4]]: memref,
+// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
+// CHECK: return %[[A3]] : memref
+func.func @sparse_indices_buffer_coo(%arg0: tensor) -> memref {
+  %0 = sparse_tensor.indices_buffer %arg0 : tensor to memref
+  return %0 : memref
+}
+
 // CHECK-LABEL: func @sparse_noe(
 // CHECK-SAME: %[[A0:.*]]: memref,
 // CHECK-SAME: %[[A1:.*]]: memref,
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -122,11 +122,10 @@
 // CHECK-RWT: sparse_tensor.yield %[[IFR]]
 // CHECK-RWT: }
 // CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T2]] hasInserts
-// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
-// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
 // CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
 // CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
-// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]]
+// CHECK-RWT: %[[I:.*]] = sparse_tensor.indices_buffer %[[COO]]
+// CHECK-RWT: sparse_tensor.sort_coo %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
 // CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
 // CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
 // CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f64, %[[L1T:.*]]: tensor
@@ -182,11 +181,10 @@
 // CHECK-RWT: sparse_tensor.yield %[[L0T2]]
 // CHECK-RWT: }
 // CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T1]] hasInserts
-// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
-// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
 // CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
 // CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
-// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]]
+// CHECK-RWT: %[[I:.*]] = sparse_tensor.indices_buffer %[[COO]]
+// CHECK-RWT: sparse_tensor.sort_coo %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
 // CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
 // CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
 // CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor