diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp --- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp @@ -57,6 +57,7 @@ #include "mlir/ExecutionEngine/SparseTensor/Storage.h" #include +#include #include using namespace mlir::sparse_tensor; @@ -211,6 +212,57 @@ *pIndices = indices; } +//===----------------------------------------------------------------------===// +// +// Utilities for manipulating `StridedMemRefType`. +// +//===----------------------------------------------------------------------===// + +#define ASSERT_NO_STRIDE(MEMREF) \ + do { \ + assert((MEMREF) && "Memref is nullptr"); \ + assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride"); \ + } while (false) + +// All our functions use `uint64_t` for ranks, but `StridedMemRefType::sizes` +// uses `int64_t`. And we must make the cast explicit for the sake of +// `operator==`, or else it will generate a [-Wsign-compare] warning. +#define MEMREF_GET_USIZE(MEMREF) static_cast((MEMREF)->sizes[0]) + +#define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset) + +/// A version of `operator<=` which is careful to ensure that negatives +/// are always considered less than positives regardless of the signedness +/// of the two types. +// +// This is a backport (to C++17) of C++20's `cmp_less_equal`, and +// is based on the sample implementation provided by the standard: +// . +template +static constexpr bool safelyLE(T t, U u) noexcept { + using UT = std::make_unsigned_t; + using UU = std::make_unsigned_t; + if constexpr (std::is_signed_v == std::is_signed_v) + return t <= u; + else if constexpr (std::is_signed_v) + return t < 0 ? true : static_cast(t) <= u; + else + return u < 0 ? false : t <= static_cast(u); +} + +// We make this a function rather than a macro mainly for type safety +// reasons.
This function does not modify the vector, but it cannot +// be marked `const` because its data pointer is stored into the non-`const` memref. +template +static void vectorToMemref(std::vector &v, StridedMemRefType &ref) { + ref.basePtr = ref.data = v.data(); + ref.offset = 0; + assert(safelyLE(v.size(), std::numeric_limits::max()) && + "size overflow"); + ref.sizes[0] = static_cast(v.size()); + ref.strides[0] = 1; +} + } // anonymous namespace extern "C" { @@ -284,20 +336,21 @@ StridedMemRefType *lvl2dimRef, StridedMemRefType *dim2lvlRef, OverheadType ptrTp, OverheadType indTp, PrimaryType valTp, Action action, void *ptr) { - assert(dimSizesRef && dimSizesRef->strides[0] == 1); - assert(lvlSizesRef && lvlSizesRef->strides[0] == 1); - assert(lvlTypesRef && lvlTypesRef->strides[0] == 1); - assert(lvl2dimRef && lvl2dimRef->strides[0] == 1); - assert(dim2lvlRef && dim2lvlRef->strides[0] == 1); - const uint64_t dimRank = dimSizesRef->sizes[0]; - const uint64_t lvlRank = lvlSizesRef->sizes[0]; - assert(dim2lvlRef->sizes[0] == dimRank); - assert(lvlTypesRef->sizes[0] == lvlRank && lvl2dimRef->sizes[0] == lvlRank); - const index_type *dimSizes = dimSizesRef->data + dimSizesRef->offset; - const index_type *lvlSizes = lvlSizesRef->data + lvlSizesRef->offset; - const DimLevelType *lvlTypes = lvlTypesRef->data + lvlTypesRef->offset; - const index_type *lvl2dim = lvl2dimRef->data + lvl2dimRef->offset; - const index_type *dim2lvl = dim2lvlRef->data + dim2lvlRef->offset; + ASSERT_NO_STRIDE(dimSizesRef); + ASSERT_NO_STRIDE(lvlSizesRef); + ASSERT_NO_STRIDE(lvlTypesRef); + ASSERT_NO_STRIDE(lvl2dimRef); + ASSERT_NO_STRIDE(dim2lvlRef); + const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef); + const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef); + assert(MEMREF_GET_USIZE(dim2lvlRef) == dimRank); + assert(MEMREF_GET_USIZE(lvlTypesRef) == lvlRank); + assert(MEMREF_GET_USIZE(lvl2dimRef) == lvlRank); + const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef); + const index_type *lvlSizes =
MEMREF_GET_PAYLOAD(lvlSizesRef); + const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef); + const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef); + const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases. // This is safe because of the static_assert above. @@ -422,10 +475,8 @@ assert(ref &&tensor); \ std::vector *v; \ static_cast(tensor)->getValues(&v); \ - ref->basePtr = ref->data = v->data(); \ - ref->offset = 0; \ - ref->sizes[0] = v->size(); \ - ref->strides[0] = 1; \ + assert(v); \ + vectorToMemref(*v, *ref); \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES) #undef IMPL_SPARSEVALUES @@ -436,10 +487,8 @@ assert(ref &&tensor); \ std::vector *v; \ static_cast(tensor)->LIB(&v, d); \ - ref->basePtr = ref->data = v->data(); \ - ref->offset = 0; \ - ref->sizes[0] = v->size(); \ - ref->strides[0] = 1; \ + assert(v); \ + vectorToMemref(*v, *ref); \ } #define IMPL_SPARSEPOINTERS(PNAME, P) \ IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers) @@ -460,16 +509,17 @@ void *lvlCOO, StridedMemRefType *vref, \ StridedMemRefType *dimIndRef, \ StridedMemRefType *dim2lvlRef) { \ - assert(lvlCOO &&vref &&dimIndRef &&dim2lvlRef); \ - assert(dimIndRef->strides[0] == 1 && dim2lvlRef->strides[0] == 1); \ - const uint64_t rank = dimIndRef->sizes[0]; \ - assert(dim2lvlRef->sizes[0] == rank); \ - const index_type *dimInd = dimIndRef->data + dimIndRef->offset; \ - const index_type *dim2lvl = dim2lvlRef->data + dim2lvlRef->offset; \ + assert(lvlCOO &&vref); \ + ASSERT_NO_STRIDE(dimIndRef); \ + ASSERT_NO_STRIDE(dim2lvlRef); \ + const uint64_t rank = MEMREF_GET_USIZE(dimIndRef); \ + assert(MEMREF_GET_USIZE(dim2lvlRef) == rank); \ + const index_type *dimInd = MEMREF_GET_PAYLOAD(dimIndRef); \ + const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); \ std::vector lvlInd(rank); \ for (uint64_t d = 0; d < rank; ++d) \ lvlInd[dim2lvl[d]] = dimInd[d]; \ - V *value = vref->data + vref->offset; \ + V
*value = MEMREF_GET_PAYLOAD(vref); \ static_cast *>(lvlCOO)->add(lvlInd, *value); \ return lvlCOO; \ } @@ -480,11 +530,11 @@ bool _mlir_ciface_getNext##VNAME(void *iter, \ StridedMemRefType *iref, \ StridedMemRefType *vref) { \ - assert(iter &&iref &&vref); \ - assert(iref->strides[0] == 1); \ - index_type *indx = iref->data + iref->offset; \ - V *value = vref->data + vref->offset; \ - const uint64_t isize = iref->sizes[0]; \ + assert(iter &&vref); \ + ASSERT_NO_STRIDE(iref); \ + index_type *indx = MEMREF_GET_PAYLOAD(iref); \ + V *value = MEMREF_GET_PAYLOAD(vref); \ + const uint64_t isize = MEMREF_GET_USIZE(iref); \ const Element *elem = \ static_cast *>(iter)->getNext(); \ if (elem == nullptr) \ @@ -501,11 +551,11 @@ void _mlir_ciface_lexInsert##VNAME(void *tensor, \ StridedMemRefType *cref, \ StridedMemRefType *vref) { \ - assert(tensor &&cref &&vref); \ - assert(cref->strides[0] == 1); \ - index_type *cursor = cref->data + cref->offset; \ + assert(tensor &&vref); \ + ASSERT_NO_STRIDE(cref); \ + index_type *cursor = MEMREF_GET_PAYLOAD(cref); \ assert(cursor); \ - V *value = vref->data + vref->offset; \ + V *value = MEMREF_GET_PAYLOAD(vref); \ static_cast(tensor)->lexInsert(cursor, *value); \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT) @@ -516,16 +566,16 @@ void *tensor, StridedMemRefType *cref, \ StridedMemRefType *vref, StridedMemRefType *fref, \ StridedMemRefType *aref, index_type count) { \ - assert(tensor &&cref &&vref &&fref &&aref); \ - assert(cref->strides[0] == 1); \ - assert(vref->strides[0] == 1); \ - assert(fref->strides[0] == 1); \ - assert(aref->strides[0] == 1); \ - assert(vref->sizes[0] == fref->sizes[0]); \ - index_type *cursor = cref->data + cref->offset; \ - V *values = vref->data + vref->offset; \ - bool *filled = fref->data + fref->offset; \ - index_type *added = aref->data + aref->offset; \ + assert(tensor); \ + ASSERT_NO_STRIDE(cref); \ + ASSERT_NO_STRIDE(vref); \ + ASSERT_NO_STRIDE(fref); \ + ASSERT_NO_STRIDE(aref); \ +
assert(MEMREF_GET_USIZE(vref) == MEMREF_GET_USIZE(fref)); \ + index_type *cursor = MEMREF_GET_PAYLOAD(cref); \ + V *values = MEMREF_GET_PAYLOAD(vref); \ + bool *filled = MEMREF_GET_PAYLOAD(fref); \ + index_type *added = MEMREF_GET_PAYLOAD(aref); \ static_cast(tensor)->expInsert( \ cursor, values, filled, added, count); \ } @@ -644,9 +694,9 @@ void _mlir_ciface_getSparseTensorReaderDimSizes( void *p, StridedMemRefType *dref) { - assert(p && dref); - assert(dref->strides[0] == 1); - index_type *dimSizes = dref->data + dref->offset; + assert(p); + ASSERT_NO_STRIDE(dref); + index_type *dimSizes = MEMREF_GET_PAYLOAD(dref); SparseTensorReader &file = *static_cast(p); const index_type *sizes = file.getDimSizes(); index_type rank = file.getRank(); @@ -662,12 +712,12 @@ void _mlir_ciface_getSparseTensorReaderNext##VNAME( \ void *p, StridedMemRefType *iref, \ StridedMemRefType *vref) { \ - assert(p &&iref &&vref); \ - assert(iref->strides[0] == 1); \ - index_type *indices = iref->data + iref->offset; \ + assert(p &&vref); \ + ASSERT_NO_STRIDE(iref); \ + index_type *indices = MEMREF_GET_PAYLOAD(iref); \ SparseTensorReader *stfile = static_cast(p); \ index_type rank = stfile->getRank(); \ - V *value = vref->data + vref->offset; \ + V *value = MEMREF_GET_PAYLOAD(vref); \ *value = stfile->readCOOElement(rank, indices); \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT) @@ -691,10 +741,10 @@ void _mlir_ciface_outSparseTensorWriterMetaData( void *p, index_type rank, index_type nnz, StridedMemRefType *dref) { - assert(p && dref); - assert(dref->strides[0] == 1); + assert(p); + ASSERT_NO_STRIDE(dref); assert(rank != 0); - index_type *dimSizes = dref->data + dref->offset; + index_type *dimSizes = MEMREF_GET_PAYLOAD(dref); SparseTensorWriter &file = *static_cast(p); file << rank << " " << nnz << std::endl; for (index_type r = 0; r < rank - 1; ++r) @@ -706,13 +756,13 @@ void _mlir_ciface_outSparseTensorWriterNext##VNAME( \ void *p, index_type rank, StridedMemRefType *iref, \
StridedMemRefType *vref) { \ - assert(p &&iref &&vref); \ - assert(iref->strides[0] == 1); \ - index_type *indices = iref->data + iref->offset; \ + assert(p &&vref); \ + ASSERT_NO_STRIDE(iref); \ + index_type *indices = MEMREF_GET_PAYLOAD(iref); \ SparseTensorWriter &file = *static_cast(p); \ for (uint64_t r = 0; r < rank; ++r) \ file << (indices[r] + 1) << " "; \ - V *value = vref->data + vref->offset; \ + V *value = MEMREF_GET_PAYLOAD(vref); \ file << *value << std::endl; \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT) @@ -720,4 +770,8 @@ } // extern "C" +#undef MEMREF_GET_PAYLOAD +#undef MEMREF_GET_USIZE +#undef ASSERT_NO_STRIDE + #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS