diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
@@ -0,0 +1,55 @@
+//===- SparseTensorUtils.h - Enums shared with the runtime ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines several enums shared between
+// Transforms/SparseTensorConversion.cpp and ExecutionEngine/SparseTensorUtils.cpp
+//
+//===----------------------------------------------------------------------===//

+#ifndef MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_
+#define MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_

+#include <cstdint>

+extern "C" {

+/// Encoding of the overhead storage type, for "overloading" @newSparseTensor.
+enum class OverheadType : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };

+/// Encoding of the elemental type, for "overloading" @newSparseTensor.
+enum class PrimaryType : uint32_t {
+  kF64 = 1,
+  kF32 = 2,
+  kI64 = 3,
+  kI32 = 4,
+  kI16 = 5,
+  kI8 = 6
+};

+/// The actions performed by @newSparseTensor.
+enum class Action : uint32_t {
+  kEmpty = 0,
+  kFromFile = 1,
+  kFromCOO = 2,
+  kEmptyCOO = 3,
+  kToCOO = 4,
+  kToIterator = 5
+};

+/// This enum mimics `SparseTensorEncodingAttr::DimLevelType` for
+/// breaking dependency cycles. `SparseTensorEncodingAttr::DimLevelType`
+/// is the source of truth and this enum should be kept consistent with it.
+enum class DimLevelType : uint8_t {
+  kDense = 0,
+  kCompressed = 1,
+  kSingleton = 2
+};

+} // extern "C"

+#endif // MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/ExecutionEngine/SparseTensorUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
@@ -29,69 +30,10 @@
 
 namespace {
 
-/// New tensor storage action. Keep these values consistent with
-/// the sparse runtime support library.
-enum Action : uint32_t {
-  kEmpty = 0,
-  kFromFile = 1,
-  kFromCOO = 2,
-  kEmptyCOO = 3,
-  kToCOO = 4,
-  kToIter = 5
-};
-
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Returns internal type encoding for primary storage. Keep these
-/// values consistent with the sparse runtime support library.
-static uint32_t getPrimaryTypeEncoding(Type tp) {
-  if (tp.isF64())
-    return 1;
-  if (tp.isF32())
-    return 2;
-  if (tp.isInteger(64))
-    return 3;
-  if (tp.isInteger(32))
-    return 4;
-  if (tp.isInteger(16))
-    return 5;
-  if (tp.isInteger(8))
-    return 6;
-  return 0;
-}
-
-/// Returns internal type encoding for overhead storage. Keep these
-/// values consistent with the sparse runtime support library.
-static uint32_t getOverheadTypeEncoding(unsigned width) {
-  switch (width) {
-  default:
-    return 1;
-  case 32:
-    return 2;
-  case 16:
-    return 3;
-  case 8:
-    return 4;
-  }
-}
-
-/// Returns internal dimension level type encoding. Keep these
-/// values consistent with the sparse runtime support library.
-static uint32_t
-getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
-  switch (dlt) {
-  case SparseTensorEncodingAttr::DimLevelType::Dense:
-    return 0;
-  case SparseTensorEncodingAttr::DimLevelType::Compressed:
-    return 1;
-  case SparseTensorEncodingAttr::DimLevelType::Singleton:
-    return 2;
-  }
-  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
-}
-
 /// Generates a constant zero of the given type.
 inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                  Location loc, Type t) {
@@ -116,6 +58,68 @@
   return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
 }
 
+/// Returns a constant of the internal type encoding for overhead storage.
+static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
+                                          Location loc, unsigned width) {
+  OverheadType sec;
+  switch (width) {
+  default:
+    sec = OverheadType::kU64;
+    break;
+  case 32:
+    sec = OverheadType::kU32;
+    break;
+  case 16:
+    sec = OverheadType::kU16;
+    break;
+  case 8:
+    sec = OverheadType::kU8;
+    break;
+  }
+  return constantI32(rewriter, loc, (uint32_t)sec);
+}
+
+/// Returns a constant of the internal type encoding for primary storage.
+static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
+                                         Location loc, Type tp) {
+  PrimaryType primary;
+  if (tp.isF64()) {
+    primary = PrimaryType::kF64;
+  } else if (tp.isF32()) {
+    primary = PrimaryType::kF32;
+  } else if (tp.isInteger(64)) {
+    primary = PrimaryType::kI64;
+  } else if (tp.isInteger(32)) {
+    primary = PrimaryType::kI32;
+  } else if (tp.isInteger(16)) {
+    primary = PrimaryType::kI16;
+  } else if (tp.isInteger(8)) {
+    primary = PrimaryType::kI8;
+  } else {
+    llvm_unreachable("Unknown element type");
+  }
+  return constantI32(rewriter, loc, (uint32_t)primary);
+}
+
+/// Returns a constant of the internal dimension level type encoding.
+static Value
+constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
+                             SparseTensorEncodingAttr::DimLevelType dlt) {
+  // TODO(wrengr): figure out how to avoid the repetition, yet retain the
+  // llvm_unreachable check. Collapsing the cases behind a `default` would
+  // shorten the switch, but it would also suppress the -Wswitch warning
+  // that currently forces an update whenever a new DimLevelType is added.
+  switch (dlt) {
+  case SparseTensorEncodingAttr::DimLevelType::Dense:
+    return constantI8(rewriter, loc, (uint8_t)DimLevelType::kDense);
+  case SparseTensorEncodingAttr::DimLevelType::Compressed:
+    return constantI8(rewriter, loc, (uint8_t)DimLevelType::kCompressed);
+  case SparseTensorEncodingAttr::DimLevelType::Singleton:
+    return constantI8(rewriter, loc, (uint8_t)DimLevelType::kSingleton);
+  }
+  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
+}
+
 /// Returns a function reference (first hit also inserts into module). Sets
 /// the "_emit_c_interface" on the function declaration when requested,
 /// so that LLVM lowering generates a wrapper function that takes care
@@ -238,7 +242,7 @@
 /// computation.
 static void newParams(ConversionPatternRewriter &rewriter,
                       SmallVector<Value> &params, Operation *op,
-                      SparseTensorEncodingAttr &enc, uint32_t action,
+                      SparseTensorEncodingAttr &enc, Action action,
                       ValueRange szs, Value ptr = Value()) {
   Location loc = op->getLoc();
   ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
@@ -246,7 +250,7 @@
   // Sparsity annotations.
   SmallVector<Value> attrs;
   for (unsigned i = 0; i < sz; i++)
-    attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
+    attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
   params.push_back(genBuffer(rewriter, loc, attrs));
   // Dimension sizes array of the enveloping tensor. Useful for either
   // verification of external data, or for construction of internal data.
@@ -268,18 +272,17 @@
   params.push_back(genBuffer(rewriter, loc, rev));
   // Secondary and primary types encoding.
   ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
-  uint32_t secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
-  uint32_t secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
-  uint32_t primary = getPrimaryTypeEncoding(resType.getElementType());
-  assert(primary);
-  params.push_back(constantI32(rewriter, loc, secPtr));
-  params.push_back(constantI32(rewriter, loc, secInd));
-  params.push_back(constantI32(rewriter, loc, primary));
+  params.push_back(
+      constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()));
+  params.push_back(
+      constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()));
+  params.push_back(
+      constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
   // User action and pointer.
   Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
   if (!ptr)
     ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
-  params.push_back(constantI32(rewriter, loc, action));
+  params.push_back(constantI32(rewriter, loc, (uint32_t)action));
   params.push_back(ptr);
 }
 
@@ -530,7 +533,7 @@
     SmallVector<Value> params;
     sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
     Value ptr = adaptor.getOperands()[0];
-    newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
+    newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr);
     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
   }
@@ -549,7 +552,7 @@
     // Generate the call to construct empty tensor. The sizes are
     // explicitly defined by the arguments to the init operator.
     SmallVector<Value> params;
-    newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
+    newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
   }
@@ -588,13 +591,13 @@
     auto enc = SparseTensorEncodingAttr::get(
         op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
-    newParams(rewriter, params, op, enc, kToCOO, sizes, src);
+    newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
     Value coo = genNewCall(rewriter, op, params);
-    params[3] = constantI32(
-        rewriter, loc, getOverheadTypeEncoding(encDst.getPointerBitWidth()));
-    params[4] = constantI32(
-        rewriter, loc, getOverheadTypeEncoding(encDst.getIndexBitWidth()));
-    params[6] = constantI32(rewriter, loc, kFromCOO);
+    params[3] = constantOverheadTypeEncoding(rewriter, loc,
+                                             encDst.getPointerBitWidth());
+    params[4] = constantOverheadTypeEncoding(rewriter, loc,
+                                             encDst.getIndexBitWidth());
+    params[6] = constantI32(rewriter, loc, (uint32_t)Action::kFromCOO);
     params[7] = coo;
     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
@@ -613,7 +616,7 @@
     Type elemTp = dstTensorTp.getElementType();
     // Fabricate a no-permutation encoding for newParams().
     // The pointer/index types must be those of `src`.
-    // The dimLevelTypes aren't actually used by kToIter.
+    // The dimLevelTypes aren't actually used by Action::kToIterator.
     encDst = SparseTensorEncodingAttr::get(
         op->getContext(),
         SmallVector<SparseTensorEncodingAttr::DimLevelType>(
            rank, SparseTensorEncodingAttr::DimLevelType::Dense),
        AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
     SmallVector<Value> sizes;
     SmallVector<Value> params;
     sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
-    newParams(rewriter, params, op, encDst, kToIter, sizes, src);
+    newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
     Value iter = genNewCall(rewriter, op, params);
     Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
     Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
@@ -677,7 +680,7 @@
     SmallVector<Value> sizes;
     SmallVector<Value> params;
     sizesFromSrc(rewriter, sizes, loc, src);
-    newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
+    newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
     Value ptr = genNewCall(rewriter, op, params);
     Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
     Value perm = params[2];
@@ -718,7 +721,7 @@
           return {};
         });
     // Final call to construct sparse tensor storage.
-    params[6] = constantI32(rewriter, loc, kFromCOO);
+    params[6] = constantI32(rewriter, loc, (uint32_t)Action::kFromCOO);
     params[7] = ptr;
     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -14,6 +14,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/ExecutionEngine/SparseTensorUtils.h"
 #include "mlir/ExecutionEngine/CRunnerUtils.h"
 
 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
@@ -162,8 +163,6 @@
 /// function overloading to implement "partial" method specialization.
 class SparseTensorStorageBase {
 public:
-  enum DimLevelType : uint8_t { kDense = 0, kCompressed = 1, kSingleton = 2 };
-
   virtual uint64_t getDimSize(uint64_t) = 0;
 
   // Overhead storage.
@@ -206,7 +205,7 @@
   /// permutation, and per-dimension dense/sparse annotations, using
   /// the coordinate scheme tensor for the initial contents if provided.
   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
-                      const uint8_t *sparsity, SparseTensorCOO<V> *tensor)
+                      const DimLevelType *sparsity, SparseTensorCOO<V> *tensor)
       : sizes(szs), rev(getRank()), pointers(getRank()), indices(getRank()) {
     uint64_t rank = getRank();
     // Store "reverse" permutation.
@@ -216,17 +215,18 @@
     // TODO: needs fine-tuning based on sparsity
     for (uint64_t r = 0, s = 1; r < rank; r++) {
       s *= sizes[r];
-      if (sparsity[r] == kCompressed) {
+      if (sparsity[r] == DimLevelType::kCompressed) {
         pointers[r].reserve(s + 1);
         indices[r].reserve(s);
         s = 1;
       } else {
-        assert(sparsity[r] == kDense && "singleton not yet supported");
+        assert(sparsity[r] == DimLevelType::kDense &&
+               "singleton not yet supported");
       }
     }
     // Prepare sparse pointer structures for all dimensions.
     for (uint64_t r = 0; r < rank; r++)
-      if (sparsity[r] == kCompressed)
+      if (sparsity[r] == DimLevelType::kCompressed)
         pointers[r].push_back(0);
     // Then assign contents from coordinate scheme tensor if provided.
     if (tensor) {
@@ -288,7 +288,7 @@
   /// permutation as is desired for the new sparse tensor storage.
   static SparseTensorStorage<P, I, V> *
   newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
-                  const uint8_t *sparsity, SparseTensorCOO<V> *tensor) {
+                  const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
     SparseTensorStorage<P, I, V> *n = nullptr;
     if (tensor) {
       assert(tensor->getRank() == rank);
@@ -311,8 +311,8 @@
   /// Initializes sparse tensor storage scheme from a memory-resident sparse
   /// tensor in coordinate scheme. This method prepares the pointers and
   /// indices arrays under the given per-dimension dense/sparse annotations.
-  void fromCOO(SparseTensorCOO<V> *tensor, const uint8_t *sparsity, uint64_t lo,
-               uint64_t hi, uint64_t d) {
+  void fromCOO(SparseTensorCOO<V> *tensor, const DimLevelType *sparsity,
+               uint64_t lo, uint64_t hi, uint64_t d) {
     const std::vector<Element<V>> &elements = tensor->getElements();
     // Once dimensions are exhausted, insert the numerical values.
     if (d == getRank()) {
@@ -331,7 +331,7 @@
       while (seg < hi && elements[seg].indices[d] == idx)
         seg++;
       // Handle segment in interval for sparse or dense dimension.
-      if (sparsity[d] == kCompressed) {
+      if (sparsity[d] == DimLevelType::kCompressed) {
         indices[d].push_back(idx);
       } else {
         // For dense storage we must fill in all the zero values between
@@ -346,7 +346,7 @@
       lo = seg;
     }
     // Finalize the sparse pointer structure at this dimension.
-    if (sparsity[d] == kCompressed) {
+    if (sparsity[d] == DimLevelType::kCompressed) {
       pointers[d].push_back(indices[d].size());
     } else {
       // For dense storage we must fill in all the zero values after
@@ -543,53 +543,35 @@
 //
 //===----------------------------------------------------------------------===//
 
-enum OverheadTypeEnum : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
-
-enum PrimaryTypeEnum : uint32_t {
-  kF64 = 1,
-  kF32 = 2,
-  kI64 = 3,
-  kI32 = 4,
-  kI16 = 5,
-  kI8 = 6
-};
-
-enum Action : uint32_t {
-  kEmpty = 0,
-  kFromFile = 1,
-  kFromCOO = 2,
-  kEmptyCOO = 3,
-  kToCOO = 4,
-  kToIter = 5
-};
-
 #define CASE(p, i, v, P, I, V)                                                 \
   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
     SparseTensorCOO<V> *tensor = nullptr;                                      \
-    if (action <= kFromCOO) {                                                  \
-      if (action == kFromFile) {                                               \
+    if (action <= Action::kFromCOO) {                                          \
+      if (action == Action::kFromFile) {                                       \
         char *filename = static_cast<char *>(ptr);                             \
         tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm);          \
-      } else if (action == kFromCOO) {                                         \
+      } else if (action == Action::kFromCOO) {                                 \
         tensor = static_cast<SparseTensorCOO<V> *>(ptr);                       \
       } else {                                                                 \
-        assert(action == kEmpty);                                              \
+        assert(action == Action::kEmpty);                                      \
      }                                                                        \
       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm,  \
                                                            sparsity, tensor);  \
-    } else if (action == kEmptyCOO) {                                          \
+    } else if (action == Action::kEmptyCOO) {                                  \
       return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm);        \
     } else {                                                                   \
       tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);  \
-      if (action == kToIter) {                                                 \
+      if (action == Action::kToIterator) {                                     \
         tensor->startIterator();                                               \
       } else {                                                                 \
-        assert(action == kToCOO);                                              \
+        assert(action == Action::kToCOO);                                      \
       }                                                                       \
       return tensor;                                                           \
     }                                                                          \
   }
 
+#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
+
 #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
     assert(ref);                                                               \
@@ -656,78 +638,110 @@
 /// Constructs a new sparse tensor. This is the "swiss army knife"
 /// method for materializing sparse tensors into the computation.
 ///
-/// action:
+/// Action:
 ///   kEmpty = returns empty storage to fill later
 ///   kFromFile = returns storage, where ptr contains filename to read
 ///   kFromCOO = returns storage, where ptr contains coordinate scheme to assign
 ///   kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
 ///   kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
-///   kToIter = returns iterator from storage in ptr (call getNext() to use)
+///   kToIterator = returns iterator from storage in ptr (call getNext() to use)
 void *
-_mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
+_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
                              StridedMemRefType<index_t, 1> *sref,
                              StridedMemRefType<index_t, 1> *pref,
-                             uint32_t ptrTp, uint32_t indTp, uint32_t valTp,
-                             uint32_t action, void *ptr) {
+                             OverheadType ptrTp, OverheadType indTp,
+                             PrimaryType valTp, Action action, void *ptr) {
   assert(aref && sref && pref);
   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
          pref->strides[0] == 1);
   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
-  const uint8_t *sparsity = aref->data + aref->offset;
+  const DimLevelType *sparsity = aref->data + aref->offset;
   const index_t *sizes = sref->data + sref->offset;
   const index_t *perm = pref->data + pref->offset;
   uint64_t rank = aref->sizes[0];
 
   // Double matrices with all combinations of overhead storage.
-  CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
-  CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
-  CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
-  CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
-  CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
-  CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
-  CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
-  CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
-  CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
-  CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
-  CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
-  CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
-  CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
-  CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
-  CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
-  CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
+       uint64_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
+       uint32_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
+       uint16_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
+       uint8_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
+       uint64_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
+       uint32_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
+       uint16_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
+       uint8_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
+       uint64_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
+       uint32_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
+       uint16_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
+       uint8_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
+       uint64_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
+       uint32_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
+       uint16_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
+       uint8_t, double);
 
   // Float matrices with all combinations of overhead storage.
-  CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
-  CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
-  CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
-  CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
-  CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
-  CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
-  CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
-  CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
-  CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
-  CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
-  CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
-  CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
-  CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
-  CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
-  CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
-  CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
+       uint64_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
+       uint32_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
+       uint16_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
+       uint8_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
+       uint64_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
+       uint32_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
+       uint16_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
+       uint8_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
+       uint64_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
+       uint32_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
+       uint16_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
+       uint8_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
+       uint64_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
+       uint32_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
+       uint16_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
+       uint8_t, float);
 
-  // Integral matrices with same overhead storage.
-  CASE(kU64, kU64, kI64, uint64_t, uint64_t, int64_t);
-  CASE(kU64, kU64, kI32, uint64_t, uint64_t, int32_t);
-  CASE(kU64, kU64, kI16, uint64_t, uint64_t, int16_t);
-  CASE(kU64, kU64, kI8, uint64_t, uint64_t, int8_t);
-  CASE(kU32, kU32, kI32, uint32_t, uint32_t, int32_t);
-  CASE(kU32, kU32, kI16, uint32_t, uint32_t, int16_t);
-  CASE(kU32, kU32, kI8, uint32_t, uint32_t, int8_t);
-  CASE(kU16, kU16, kI32, uint16_t, uint16_t, int32_t);
-  CASE(kU16, kU16, kI16, uint16_t, uint16_t, int16_t);
-  CASE(kU16, kU16, kI8, uint16_t, uint16_t, int8_t);
-  CASE(kU8, kU8, kI32, uint8_t, uint8_t, int32_t);
-  CASE(kU8, kU8, kI16, uint8_t, uint8_t, int16_t);
-  CASE(kU8, kU8, kI8, uint8_t, uint8_t, int8_t);
+  // Integral matrices with both overheads of the same type.
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
+  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
+  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
+  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
+  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
+  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
+  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
+  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
+  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
+  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
 
   // Unsupported case (add above if needed).
   fputs("unsupported combination of types\n", stderr);
@@ -830,7 +844,7 @@
 void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
                                 double *values, uint64_t *indices) {
   // Setup all-dims compressed and default ordering.
-  std::vector<uint8_t> sparse(rank, SparseTensorStorageBase::kCompressed);
+  std::vector<DimLevelType> sparse(rank, DimLevelType::kCompressed);
   std::vector<uint64_t> perm(rank);
   std::iota(perm.begin(), perm.end(), 0);
   // Convert external format to internal COO.
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1707,7 +1707,10 @@
 cc_library(
     name = "SparseTensorTransforms",
     srcs = glob(["lib/Dialect/SparseTensor/Transforms/*.cpp"]),
-    hdrs = ["include/mlir/Dialect/SparseTensor/Transforms/Passes.h"],
+    hdrs = [
+        "include/mlir/Dialect/SparseTensor/Transforms/Passes.h",
+        "include/mlir/ExecutionEngine/SparseTensorUtils.h",
+    ],
     includes = ["include"],
     deps = [
         ":Affine",
@@ -5391,7 +5394,10 @@
         "lib/ExecutionEngine/CRunnerUtils.cpp",
         "lib/ExecutionEngine/SparseTensorUtils.cpp",
     ],
-    hdrs = ["include/mlir/ExecutionEngine/CRunnerUtils.h"],
+    hdrs = [
+        "include/mlir/ExecutionEngine/CRunnerUtils.h",
+        "include/mlir/ExecutionEngine/SparseTensorUtils.h",
+    ],
     includes = ["include"],
 )
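
A note on the "kept consistent" contract for DimLevelType: since the compiler
side (SparseTensorConversion.cpp) now includes the same header as the runtime,
that invariant could also be enforced mechanically rather than by comment
alone. The following sketch is not part of the patch; it assumes the attribute
enum `SparseTensorEncodingAttr::DimLevelType` keeps the numeric values
Dense = 0, Compressed = 1, Singleton = 2 (the same values the removed
getDimLevelTypeEncoding() returned), and it would sit next to
constantDimLevelTypeEncoding(), where both enums are visible:

  // Hypothetical guards (not in this patch): fail the build if the runtime
  // enum ever drifts from its source of truth in the attribute definition.
  static_assert(static_cast<uint8_t>(DimLevelType::kDense) ==
                    static_cast<uint8_t>(
                        SparseTensorEncodingAttr::DimLevelType::Dense),
                "DimLevelType::kDense out of sync");
  static_assert(static_cast<uint8_t>(DimLevelType::kCompressed) ==
                    static_cast<uint8_t>(
                        SparseTensorEncodingAttr::DimLevelType::Compressed),
                "DimLevelType::kCompressed out of sync");
  static_assert(static_cast<uint8_t>(DimLevelType::kSingleton) ==
                    static_cast<uint8_t>(
                        SparseTensorEncodingAttr::DimLevelType::Singleton),
                "DimLevelType::kSingleton out of sync");

With such guards in place, constantDimLevelTypeEncoding() could in principle
collapse to a single static_cast, though keeping the explicit switch preserves
the -Wswitch warning that flags any newly added dimension level type.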