diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp --- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp @@ -50,6 +50,7 @@ #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS +#include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h" #include "mlir/ExecutionEngine/SparseTensor/COO.h" #include "mlir/ExecutionEngine/SparseTensor/ErrorHandling.h" #include "mlir/ExecutionEngine/SparseTensor/File.h" @@ -213,6 +214,47 @@ *pIndices = indices; } +//===----------------------------------------------------------------------===// +// +// Utilities for manipulating `StridedMemRefType`. +// +//===----------------------------------------------------------------------===// + +// We shouldn't need to use `detail::safelyEQ` here since the `1` is a literal. +#define ASSERT_NO_STRIDE(MEMREF) \ + do { \ + assert((MEMREF) && "Memref is nullptr"); \ + assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride"); \ + } while (false) + +// All our functions use `uint64_t` for ranks, but `StridedMemRefType::sizes` +// uses `int64_t` on some platforms. So we explicitly cast this lookup to +// ensure we get a consistent type, and we use `checkOverflowCast` rather +// than `static_cast` just to be extremely sure that the casting can't +// go awry. (The cast should always be safe since (1) sizes should never +// be negative, and (2) the maximum `int64_t` is smaller than the maximum +// `uint64_t`. But it's better to be safe than sorry.) +#define MEMREF_GET_USIZE(MEMREF) \ + detail::checkOverflowCast((MEMREF)->sizes[0]) + +#define ASSERT_USIZE_EQ(MEMREF, SZ) \ + assert(detail::safelyEQ(MEMREF_GET_USIZE(MEMREF), (SZ)) && \ + "Memref size mismatch") + +#define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset) + +// We make this a function rather than a macro mainly for type safety +// reasons.
This function does not modify the vector, but it cannot +// be marked `const` because its data pointer is stored into the non-`const` memref. +template +static void vectorToMemref(std::vector &v, StridedMemRefType &ref) { + ref.basePtr = ref.data = v.data(); + ref.offset = 0; + using SizeT = typename std::remove_reference_t; + ref.sizes[0] = detail::checkOverflowCast(v.size()); + ref.strides[0] = 1; +} + } // anonymous namespace extern "C" { @@ -286,21 +328,21 @@ StridedMemRefType *lvl2dimRef, StridedMemRefType *dim2lvlRef, OverheadType ptrTp, OverheadType indTp, PrimaryType valTp, Action action, void *ptr) { - assert(dimSizesRef && dimSizesRef->strides[0] == 1); - assert(lvlSizesRef && lvlSizesRef->strides[0] == 1); - assert(lvlTypesRef && lvlTypesRef->strides[0] == 1); - assert(lvl2dimRef && lvl2dimRef->strides[0] == 1); - assert(dim2lvlRef && dim2lvlRef->strides[0] == 1); - const uint64_t dimRank = dimSizesRef->sizes[0]; - const uint64_t lvlRank = lvlSizesRef->sizes[0]; - assert(dim2lvlRef->sizes[0] == (int64_t)dimRank); - assert(lvlTypesRef->sizes[0] == (int64_t)lvlRank && - lvl2dimRef->sizes[0] == (int64_t)lvlRank); - const index_type *dimSizes = dimSizesRef->data + dimSizesRef->offset; - const index_type *lvlSizes = lvlSizesRef->data + lvlSizesRef->offset; - const DimLevelType *lvlTypes = lvlTypesRef->data + lvlTypesRef->offset; - const index_type *lvl2dim = lvl2dimRef->data + lvl2dimRef->offset; - const index_type *dim2lvl = dim2lvlRef->data + dim2lvlRef->offset; + ASSERT_NO_STRIDE(dimSizesRef); + ASSERT_NO_STRIDE(lvlSizesRef); + ASSERT_NO_STRIDE(lvlTypesRef); + ASSERT_NO_STRIDE(lvl2dimRef); + ASSERT_NO_STRIDE(dim2lvlRef); + const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef); + const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef); + ASSERT_USIZE_EQ(dim2lvlRef, dimRank); + ASSERT_USIZE_EQ(lvlTypesRef, lvlRank); + ASSERT_USIZE_EQ(lvl2dimRef, lvlRank); + const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef); + const index_type *lvlSizes =
MEMREF_GET_PAYLOAD(lvlSizesRef); + const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef); + const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef); + const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases. // This is safe because of the static_assert above. @@ -425,10 +467,8 @@ assert(ref &&tensor); \ std::vector *v; \ static_cast(tensor)->getValues(&v); \ - ref->basePtr = ref->data = v->data(); \ - ref->offset = 0; \ - ref->sizes[0] = v->size(); \ - ref->strides[0] = 1; \ + assert(v); \ + vectorToMemref(*v, *ref); \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES) #undef IMPL_SPARSEVALUES @@ -439,10 +479,8 @@ assert(ref &&tensor); \ std::vector *v; \ static_cast(tensor)->LIB(&v, d); \ - ref->basePtr = ref->data = v->data(); \ - ref->offset = 0; \ - ref->sizes[0] = v->size(); \ - ref->strides[0] = 1; \ + assert(v); \ + vectorToMemref(*v, *ref); \ } #define IMPL_SPARSEPOINTERS(PNAME, P) \ IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers) @@ -463,16 +501,17 @@ void *lvlCOO, StridedMemRefType *vref, \ StridedMemRefType *dimIndRef, \ StridedMemRefType *dim2lvlRef) { \ - assert(lvlCOO &&vref &&dimIndRef &&dim2lvlRef); \ - assert(dimIndRef->strides[0] == 1 && dim2lvlRef->strides[0] == 1); \ - const uint64_t rank = dimIndRef->sizes[0]; \ - assert(dim2lvlRef->sizes[0] == (int64_t)rank); \ - const index_type *dimInd = dimIndRef->data + dimIndRef->offset; \ - const index_type *dim2lvl = dim2lvlRef->data + dim2lvlRef->offset; \ + assert(lvlCOO &&vref); \ + ASSERT_NO_STRIDE(dimIndRef); \ + ASSERT_NO_STRIDE(dim2lvlRef); \ + const uint64_t rank = MEMREF_GET_USIZE(dimIndRef); \ + ASSERT_USIZE_EQ(dim2lvlRef, rank); \ + const index_type *dimInd = MEMREF_GET_PAYLOAD(dimIndRef); \ + const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); \ std::vector lvlInd(rank); \ for (uint64_t d = 0; d < rank; ++d) \ lvlInd[dim2lvl[d]] = dimInd[d]; \ - V *value = vref->data + vref->offset; \ + V
*value = MEMREF_GET_PAYLOAD(vref); \ static_cast *>(lvlCOO)->add(lvlInd, *value); \ return lvlCOO; \ } @@ -483,11 +522,11 @@ bool _mlir_ciface_getNext##VNAME(void *iter, \ StridedMemRefType *iref, \ StridedMemRefType *vref) { \ - assert(iter &&iref &&vref); \ - assert(iref->strides[0] == 1); \ - index_type *indx = iref->data + iref->offset; \ - V *value = vref->data + vref->offset; \ - const uint64_t isize = iref->sizes[0]; \ + assert(iter &&vref); \ + ASSERT_NO_STRIDE(iref); \ + index_type *indx = MEMREF_GET_PAYLOAD(iref); \ + V *value = MEMREF_GET_PAYLOAD(vref); \ + const uint64_t isize = MEMREF_GET_USIZE(iref); \ const Element *elem = \ static_cast *>(iter)->getNext(); \ if (elem == nullptr) \ @@ -504,11 +543,11 @@ void _mlir_ciface_lexInsert##VNAME(void *tensor, \ StridedMemRefType *cref, \ StridedMemRefType *vref) { \ - assert(tensor &&cref &&vref); \ - assert(cref->strides[0] == 1); \ - index_type *cursor = cref->data + cref->offset; \ + assert(tensor &&vref); \ + ASSERT_NO_STRIDE(cref); \ + index_type *cursor = MEMREF_GET_PAYLOAD(cref); \ assert(cursor); \ - V *value = vref->data + vref->offset; \ + V *value = MEMREF_GET_PAYLOAD(vref); \ static_cast(tensor)->lexInsert(cursor, *value); \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT) @@ -519,16 +558,16 @@ void *tensor, StridedMemRefType *cref, \ StridedMemRefType *vref, StridedMemRefType *fref, \ StridedMemRefType *aref, index_type count) { \ - assert(tensor &&cref &&vref &&fref &&aref); \ - assert(cref->strides[0] == 1); \ - assert(vref->strides[0] == 1); \ - assert(fref->strides[0] == 1); \ - assert(aref->strides[0] == 1); \ - assert(vref->sizes[0] == fref->sizes[0]); \ - index_type *cursor = cref->data + cref->offset; \ - V *values = vref->data + vref->offset; \ - bool *filled = fref->data + fref->offset; \ - index_type *added = aref->data + aref->offset; \ + assert(tensor); \ + ASSERT_NO_STRIDE(cref); \ + ASSERT_NO_STRIDE(vref); \ + ASSERT_NO_STRIDE(fref); \ + ASSERT_NO_STRIDE(aref); \ +
ASSERT_USIZE_EQ(vref, MEMREF_GET_USIZE(fref)); \ + index_type *cursor = MEMREF_GET_PAYLOAD(cref); \ + V *values = MEMREF_GET_PAYLOAD(vref); \ + bool *filled = MEMREF_GET_PAYLOAD(fref); \ + index_type *added = MEMREF_GET_PAYLOAD(aref); \ static_cast(tensor)->expInsert( \ cursor, values, filled, added, count); \ } @@ -537,13 +576,13 @@ void _mlir_ciface_getSparseTensorReaderDimSizes( void *p, StridedMemRefType *dref) { - assert(p && dref); - assert(dref->strides[0] == 1); - index_type *dimSizes = dref->data + dref->offset; + assert(p); + ASSERT_NO_STRIDE(dref); + index_type *dimSizes = MEMREF_GET_PAYLOAD(dref); SparseTensorReader &file = *static_cast(p); const index_type *sizes = file.getDimSizes(); index_type rank = file.getRank(); - for (uint64_t r = 0; r < rank; ++r) + for (index_type r = 0; r < rank; ++r) dimSizes[r] = sizes[r]; } @@ -551,12 +590,12 @@ void _mlir_ciface_getSparseTensorReaderNext##VNAME( \ void *p, StridedMemRefType *iref, \ StridedMemRefType *vref) { \ - assert(p &&iref &&vref); \ - assert(iref->strides[0] == 1); \ - index_type *indices = iref->data + iref->offset; \ + assert(p &&vref); \ + ASSERT_NO_STRIDE(iref); \ + index_type *indices = MEMREF_GET_PAYLOAD(iref); \ SparseTensorReader *stfile = static_cast(p); \ index_type rank = stfile->getRank(); \ - V *value = vref->data + vref->offset; \ + V *value = MEMREF_GET_PAYLOAD(vref); \ *value = stfile->readCOOElement(rank, indices); \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT) @@ -565,10 +604,10 @@ void _mlir_ciface_outSparseTensorWriterMetaData( void *p, index_type rank, index_type nnz, StridedMemRefType *dref) { - assert(p && dref); - assert(dref->strides[0] == 1); + assert(p); + ASSERT_NO_STRIDE(dref); assert(rank != 0); - index_type *dimSizes = dref->data + dref->offset; + index_type *dimSizes = MEMREF_GET_PAYLOAD(dref); SparseTensorWriter &file = *static_cast(p); file << rank << " " << nnz << std::endl; for (index_type r = 0; r < rank - 1; ++r) @@ -580,13 +619,13 @@ void
_mlir_ciface_outSparseTensorWriterNext##VNAME( \ void *p, index_type rank, StridedMemRefType *iref, \ StridedMemRefType *vref) { \ - assert(p &&iref &&vref); \ - assert(iref->strides[0] == 1); \ - index_type *indices = iref->data + iref->offset; \ + assert(p &&vref); \ + ASSERT_NO_STRIDE(iref); \ + index_type *indices = MEMREF_GET_PAYLOAD(iref); \ SparseTensorWriter &file = *static_cast(p); \ - for (uint64_t r = 0; r < rank; ++r) \ + for (index_type r = 0; r < rank; ++r) \ file << (indices[r] + 1) << " "; \ - V *value = vref->data + vref->offset; \ + V *value = MEMREF_GET_PAYLOAD(vref); \ file << *value << std::endl; \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT) @@ -723,4 +762,9 @@ } // extern "C" +#undef MEMREF_GET_PAYLOAD +#undef ASSERT_USIZE_EQ +#undef MEMREF_GET_USIZE +#undef ASSERT_NO_STRIDE + #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS