diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -27,7 +27,19 @@ namespace { -/// Returns internal type encoding for overhead storage. +/// Internal encoding of primary storage. Keep this enum consistent +/// with the equivalent enum in the sparse runtime support library. +enum PrimaryTypeEnum : uint64_t { + kF64 = 1, + kF32 = 2, + kI64 = 3, + kI32 = 4, + kI16 = 5, + kI8 = 6 +}; + +/// Returns internal type encoding for overhead storage. Keep these +/// values consistent with the sparse runtime support library. static unsigned getOverheadTypeEncoding(unsigned width) { switch (width) { default: @@ -41,7 +53,8 @@ } } -/// Returns internal dimension level type encoding. +/// Returns internal dimension level type encoding. Keep these +/// values consistent with the sparse runtime support library. 
static unsigned getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) { switch (dlt) { @@ -159,15 +172,17 @@ unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth()); unsigned primary; if (eltType.isF64()) - primary = 1; + primary = kF64; else if (eltType.isF32()) - primary = 2; + primary = kF32; + else if (eltType.isInteger(64)) + primary = kI64; else if (eltType.isInteger(32)) - primary = 3; + primary = kI32; else if (eltType.isInteger(16)) - primary = 4; + primary = kI16; else if (eltType.isInteger(8)) - primary = 5; + primary = kI8; else return failure(); params.push_back( @@ -256,6 +271,8 @@ name = "sparseValuesF64"; else if (eltType.isF32()) name = "sparseValuesF32"; + else if (eltType.isInteger(64)) + name = "sparseValuesI64"; else if (eltType.isInteger(32)) name = "sparseValuesI32"; else if (eltType.isInteger(16)) diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp --- a/mlir/lib/ExecutionEngine/SparseUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp @@ -129,6 +129,7 @@ // Primary storage. 
virtual void getValues(std::vector<double> **) { fatal("valf64"); } virtual void getValues(std::vector<float> **) { fatal("valf32"); } + virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); } virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); } virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); } virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); } @@ -437,6 +438,7 @@ TEMPLATE(MemRef1DU32, uint32_t); TEMPLATE(MemRef1DU16, uint16_t); TEMPLATE(MemRef1DU8, uint8_t); +TEMPLATE(MemRef1DI64, int64_t); TEMPLATE(MemRef1DI32, int32_t); TEMPLATE(MemRef1DI16, int16_t); TEMPLATE(MemRef1DI8, int8_t); @@ -448,9 +450,10 @@ enum PrimaryTypeEnum : uint64_t { kF64 = 1, kF32 = 2, - kI32 = 3, - kI16 = 4, - kI8 = 5 + kI64 = 3, + kI32 = 4, + kI16 = 5, + kI8 = 6 }; void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata, @@ -499,6 +502,7 @@ CASE(kU8, kU8, kF32, uint8_t, uint8_t, float); // Integral matrices with same overhead storage. + CASE(kU64, kU64, kI64, uint64_t, uint64_t, int64_t); CASE(kU64, kU64, kI32, uint64_t, uint64_t, int32_t); CASE(kU64, kU64, kI16, uint64_t, uint64_t, int16_t); CASE(kU64, kU64, kI8, uint64_t, uint64_t, int8_t); @@ -535,6 +539,7 @@ IMPL2(MemRef1DU8, sparseIndices8, uint8_t, getIndices) IMPL1(MemRef1DF64, sparseValuesF64, double, getValues) IMPL1(MemRef1DF32, sparseValuesF32, float, getValues) +IMPL1(MemRef1DI64, sparseValuesI64, int64_t, getValues) IMPL1(MemRef1DI32, sparseValuesI32, int32_t, getValues) IMPL1(MemRef1DI16, sparseValuesI16, int16_t, getValues) IMPL1(MemRef1DI8, sparseValuesI8, int8_t, getValues)