diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -138,7 +138,13 @@
 /// versions; consequently, client code should use the predicate functions
 /// defined below, rather than relying on knowledge about the particular
 /// binary encoding.
+///
+/// The `Undef` "format" is a special value used internally for cases
+/// where we need to store an undefined or indeterminate `DimLevelType`.
+/// It should not be used externally, since it does not indicate an
+/// actual/representable format.
 enum class DimLevelType : uint8_t {
+  Undef = 0,           // 0b000_00
   Dense = 4,           // 0b001_00
   Compressed = 8,      // 0b010_00
   CompressedNu = 9,    // 0b010_01
@@ -150,20 +156,39 @@
   SingletonNuNo = 19,  // 0b100_11
 };
 
+/// Check that the `DimLevelType` contains a valid (possibly undefined) value.
+constexpr bool isValidDLT(DimLevelType dlt) {
+  const uint8_t formatBits = static_cast<uint8_t>(dlt) >> 2;
+  const uint8_t propertyBits = static_cast<uint8_t>(dlt) & 3;
+  // If undefined or dense, then must be unique and ordered.
+  // Otherwise, the format must be one of the known ones.
+  return (formatBits <= 1) ? (propertyBits == 0)
+                           : (formatBits == 2 || formatBits == 4);
+}
+
+/// Check if the `DimLevelType` is the special undefined value.
+constexpr bool isUndefDLT(DimLevelType dlt) {
+  return dlt == DimLevelType::Undef;
+}
+
 /// Check if the `DimLevelType` is dense.
 constexpr bool isDenseDLT(DimLevelType dlt) {
   return dlt == DimLevelType::Dense;
 }
 
+// We use the idiom `(dlt & ~3) == format` so that these predicates only
+// return true for valid DLTs. The `dlt & format` idiom is a bit faster,
+// but can return false positives on invalid DLTs.
+
 /// Check if the `DimLevelType` is compressed (regardless of properties).
 constexpr bool isCompressedDLT(DimLevelType dlt) {
-  return static_cast<uint8_t>(dlt) &
+  return (static_cast<uint8_t>(dlt) & ~3) ==
          static_cast<uint8_t>(DimLevelType::Compressed);
 }
 
 /// Check if the `DimLevelType` is singleton (regardless of properties).
 constexpr bool isSingletonDLT(DimLevelType dlt) {
-  return static_cast<uint8_t>(dlt) &
+  return (static_cast<uint8_t>(dlt) & ~3) ==
          static_cast<uint8_t>(DimLevelType::Singleton);
 }
 
@@ -178,6 +203,18 @@
 }
 
 // Ensure the above predicates work as intended.
+static_assert((isValidDLT(DimLevelType::Undef) &&
+               isValidDLT(DimLevelType::Dense) &&
+               isValidDLT(DimLevelType::Compressed) &&
+               isValidDLT(DimLevelType::CompressedNu) &&
+               isValidDLT(DimLevelType::CompressedNo) &&
+               isValidDLT(DimLevelType::CompressedNuNo) &&
+               isValidDLT(DimLevelType::Singleton) &&
+               isValidDLT(DimLevelType::SingletonNu) &&
+               isValidDLT(DimLevelType::SingletonNo) &&
+               isValidDLT(DimLevelType::SingletonNuNo)),
+              "isValidDLT definition is broken");
+
 static_assert((!isCompressedDLT(DimLevelType::Dense) &&
                isCompressedDLT(DimLevelType::Compressed) &&
                isCompressedDLT(DimLevelType::CompressedNu) &&
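NOTE (reviewer sketch, not part of the patch): the new predicates rely on the bit
layout documented in the enum above: the high bits select the format (Undef=0,
Dense=1, Compressed=2, Singleton=4 after shifting right by two) and the low two
bits clear one property each (bit 0 = Nu, "not unique"; bit 1 = No, "not
ordered"). A minimal standalone C++ check of that arithmetic, with the enum
values copied from the hunk above and everything else illustrative only:

#include <cassert>
#include <cstdint>

enum class DimLevelType : uint8_t {
  Undef = 0,           // 0b000_00
  Dense = 4,           // 0b001_00
  Compressed = 8,      // 0b010_00
  CompressedNuNo = 11, // 0b010_11
  Singleton = 16,      // 0b100_00
};

int main() {
  // The format lives in the bits above the two property bits.
  assert((static_cast<uint8_t>(DimLevelType::Compressed) >> 2) == 2);
  // CompressedNuNo sets both property bits (not-unique and not-ordered).
  assert((static_cast<uint8_t>(DimLevelType::CompressedNuNo) & 3) == 3);
  // `(dlt & ~3) == format` masks the properties before comparing, so a
  // format predicate matches all four property variants but nothing else.
  assert((static_cast<uint8_t>(DimLevelType::CompressedNuNo) & ~3) ==
         static_cast<uint8_t>(DimLevelType::Compressed));
  return 0;
}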
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -37,17 +37,20 @@
 // Dimension level types.
 //
 
-// Cannot be constexpr, because `getRank` isn't constexpr. However,
-// for some strange reason, the wrapper functions below don't trigger
-// the same [-Winvalid-constexpr] warning (despite this function not
-// being constexpr).
-inline DimLevelType getDimLevelType(RankedTensorType type, uint64_t d) {
-  assert(d < static_cast<uint64_t>(type.getRank()));
-  if (auto enc = getSparseTensorEncoding(type))
-    return enc.getDimLevelType()[d];
+constexpr DimLevelType getDimLevelType(const SparseTensorEncodingAttr &enc,
+                                       uint64_t d) {
+  if (enc) {
+    auto types = enc.getDimLevelType();
+    assert(d < types.size() && "Dimension out of bounds");
+    return types[d];
+  }
   return DimLevelType::Dense; // unannotated tensor is dense
 }
 
+constexpr DimLevelType getDimLevelType(RankedTensorType type, uint64_t d) {
+  return getDimLevelType(getSparseTensorEncoding(type), d);
+}
+
 /// Convenience function to test for dense dimension (0 <= d < rank).
 constexpr bool isDenseDim(RankedTensorType type, uint64_t d) {
   return isDenseDLT(getDimLevelType(type, d));
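NOTE (reviewer sketch, not part of the patch): the new overload makes the query
usable with a possibly-null SparseTensorEncodingAttr, preserving the dense
default for unannotated tensors. A hypothetical helper showing the intended
call chain (the function name is ours):

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

// Counts dense dimensions; for a tensor without a sparse encoding this
// returns the rank, because getDimLevelType falls back to Dense.
static unsigned countDenseDims(mlir::RankedTensorType type) {
  using namespace mlir::sparse_tensor;
  unsigned n = 0;
  for (uint64_t d = 0, rank = type.getRank(); d < rank; ++d)
    if (isDenseDim(type, d)) // isDenseDLT(getDimLevelType(type, d))
      ++n;
  return n;
}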
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -14,28 +14,13 @@
 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/BitVector.h"
 
 namespace mlir {
 namespace sparse_tensor {
 
-/// Dimension level type for a tensor (undef means index does not appear).
-enum class DimLvlType { kDense, kCompressed, kSingleton, kUndef };
-
-/// Per-dimension level format (type and properties). Dense and undefined
-/// level types should always be marked ordered and unique.
-struct DimLevelFormat {
-  DimLevelFormat(DimLvlType tp, bool o = true, bool u = true)
-      : levelType(tp), isOrdered(o), isUnique(u) {
-    assert((tp == DimLvlType::kCompressed || tp == DimLvlType::kSingleton) ||
-           (o && u));
-  }
-  DimLvlType levelType;
-  bool isOrdered;
-  bool isUnique;
-};
-
 /// Tensor expression kind.
 enum Kind {
   // Leaf.
@@ -171,8 +156,7 @@
   Merger(unsigned t, unsigned l)
       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
         hasSparseOut(false),
-        dims(t + 1, std::vector<DimLevelFormat>(
-                        l, DimLevelFormat(DimLvlType::kUndef))) {}
+        dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)) {}
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
@@ -252,28 +236,24 @@
   /// sparse vector a.
   bool isSingleCondition(unsigned t, unsigned e) const;
 
-  /// Returns true if bit corresponds to given dimension level type.
-  bool isDimLevelType(unsigned b, DimLvlType tp) const {
-    return isDimLevelType(tensor(b), index(b), tp);
-  }
-
-  /// Returns true if tensor access at index has given dimension level type.
-  bool isDimLevelType(unsigned t, unsigned i, DimLvlType tp) const {
-    return getDimLevelFormat(t, i).levelType == tp;
-  }
-
   /// Returns true if any set bit corresponds to sparse dimension level type.
   bool hasAnySparse(const BitVector &bits) const;
 
-  /// Dimension level format getter.
-  DimLevelFormat getDimLevelFormat(unsigned t, unsigned i) const {
+  /// Gets the dimension level type of the `i`th loop of the `t`th tensor.
+  DimLevelType getDimLevelType(unsigned t, unsigned i) const {
     assert(t < numTensors && i < numLoops);
-    return dims[t][i];
+    return dimTypes[t][i];
   }
 
-  /// Dimension level format setter.
-  void setDimLevelFormat(unsigned t, unsigned i, DimLevelFormat d) {
-    dims[t][i] = d;
+  /// Gets the dimension level type of `b`.
+  DimLevelType getDimLevelType(unsigned b) const {
+    return getDimLevelType(tensor(b), index(b));
+  }
+
+  /// Sets the dimension level type of the `i`th loop of the `t`th tensor.
+  void setDimLevelType(unsigned t, unsigned i, DimLevelType d) {
+    assert(isValidDLT(d));
+    dimTypes[t][i] = d;
   }
 
   // Has sparse output tensor setter.
@@ -323,7 +303,7 @@
   const unsigned numTensors;
   const unsigned numLoops;
   bool hasSparseOut;
-  std::vector<std::vector<DimLevelFormat>> dims;
+  std::vector<std::vector<DimLevelType>> dimTypes;
   llvm::SmallVector<TensorExp, 32> tensorExps;
   llvm::SmallVector<LatPoint, 16> latPoints;
   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -143,6 +143,10 @@
   printer << "<{ dimLevelType = [ ";
   for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
     switch (getDimLevelType()[i]) {
+    case DimLevelType::Undef:
+      // TODO: should probably raise an error instead of printing it...
+      printer << "\"undef\"";
+      break;
    case DimLevelType::Dense:
      printer << "\"dense\"";
      break;
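NOTE (reviewer sketch, not part of the patch): the Merger now stores
DimLevelType directly, so the old kUndef default becomes DimLevelType::Undef
and the setter rejects invalid bit patterns. A small illustration with
hypothetical tensor/loop counts:

#include <cassert>
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

void mergerApiSketch() {
  using namespace mlir::sparse_tensor;
  Merger merger(/*t=*/2, /*l=*/1); // two tensors (plus synthetic), one loop
  merger.setDimLevelType(0, 0, DimLevelType::Compressed);
  assert(isCompressedDLT(merger.getDimLevelType(0, 0)));
  // Entries that were never set keep the Undef default.
  assert(isUndefDLT(merger.getDimLevelType(1, 0)));
  // setDimLevelType(0, 0, static_cast<DimLevelType>(7)) would now trip the
  // isValidDLT assertion instead of silently storing an unknown value.
}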
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -128,60 +128,29 @@
   return AffineMap::getPermutationMap(perm, context);
 }
 
-/// Helper method to obtain the dimension level format from the encoding.
-//
-// TODO: note that we store, but currently completely *ignore* the properties
-//
-static DimLevelFormat toDimLevelFormat(const SparseTensorEncodingAttr &enc,
-                                       unsigned d) {
-  if (enc) {
-    switch (enc.getDimLevelType()[d]) {
-    case DimLevelType::Dense:
-      return DimLevelFormat(DimLvlType::kDense);
-    case DimLevelType::Compressed:
-      return DimLevelFormat(DimLvlType::kCompressed);
-    case DimLevelType::CompressedNu:
-      return DimLevelFormat(DimLvlType::kCompressed, true, false);
-    case DimLevelType::CompressedNo:
-      return DimLevelFormat(DimLvlType::kCompressed, false, true);
-    case DimLevelType::CompressedNuNo:
-      return DimLevelFormat(DimLvlType::kCompressed, false, false);
-    case DimLevelType::Singleton:
-      return DimLevelFormat(DimLvlType::kSingleton);
-    case DimLevelType::SingletonNu:
-      return DimLevelFormat(DimLvlType::kSingleton, true, false);
-    case DimLevelType::SingletonNo:
-      return DimLevelFormat(DimLvlType::kSingleton, false, true);
-    case DimLevelType::SingletonNuNo:
-      return DimLevelFormat(DimLvlType::kSingleton, false, false);
-    }
-  }
-  return DimLevelFormat(DimLvlType::kDense);
-}
-
 /// Helper method to inspect affine expressions. Rejects cases where the
 /// same index is used more than once. Also rejects compound affine
 /// expressions in sparse dimensions.
 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
-                       DimLevelFormat dim) {
+                       DimLevelType dim) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
-    if (!merger.isDimLevelType(tensor, idx, DimLvlType::kUndef))
+    if (!isUndefDLT(merger.getDimLevelType(tensor, idx)))
       return false; // used more than once
-    merger.setDimLevelFormat(tensor, idx, dim);
+    merger.setDimLevelType(tensor, idx, dim);
     return true;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
-    if (dim.levelType != DimLvlType::kDense)
+    if (!isDenseDLT(dim))
      return false; // compound only in dense dim
     auto binOp = a.cast<AffineBinaryOpExpr>();
     return findAffine(merger, tensor, binOp.getLHS(), dim) &&
            findAffine(merger, tensor, binOp.getRHS(), dim);
   }
   case AffineExprKind::Constant:
-    return dim.levelType == DimLvlType::kDense; // const only in dense dim
+    return isDenseDLT(dim); // const only in dense dim
   default:
     return false;
   }
@@ -203,7 +172,7 @@
   for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
     unsigned tensor = t->getOperandNumber();
     AffineExpr a = map.getResult(toOrigDim(enc, d));
-    if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
+    if (!findAffine(merger, tensor, a, getDimLevelType(enc, d)))
       return false; // inadmissible affine expression
   }
 }
@@ -316,16 +285,16 @@
   if (mask & SortMask::kIncludeUndef) {
     unsigned tensor = t->getOperandNumber();
     for (unsigned i = 0; i < n; i++)
-      if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
-          merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
+      if (isCompressedDLT(merger.getDimLevelType(tensor, i)) ||
+          isSingletonDLT(merger.getDimLevelType(tensor, i))) {
         for (unsigned j = 0; j < n; j++)
-          if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef)) {
+          if (isUndefDLT(merger.getDimLevelType(tensor, j))) {
            adjM[i][j] = true;
            inDegree[j]++;
          }
       } else {
-        assert(merger.isDimLevelType(tensor, i, DimLvlType::kDense) ||
-               merger.isDimLevelType(tensor, i, DimLvlType::kUndef));
+        assert(isDenseDLT(merger.getDimLevelType(tensor, i)) ||
+               isUndefDLT(merger.getDimLevelType(tensor, i)));
       }
   }
 }
@@ -364,13 +333,13 @@
   auto iteratorTypes = op.getIteratorTypesArray();
   unsigned numLoops = iteratorTypes.size();
   for (unsigned i = 0; i < numLoops; i++)
-    if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
-        merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
+    if (isCompressedDLT(merger.getDimLevelType(tensor, i)) ||
+        isSingletonDLT(merger.getDimLevelType(tensor, i))) {
      allDense = false;
      break;
     } else {
-      assert(merger.isDimLevelType(tensor, i, DimLvlType::kDense) ||
-             merger.isDimLevelType(tensor, i, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(tensor, i)) ||
+             isUndefDLT(merger.getDimLevelType(tensor, i)));
     }
   if (allDense)
     return true;
@@ -552,7 +521,7 @@
       continue; // compound
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
     // Handle the different storage schemes.
-    if (merger.isDimLevelType(tensor, idx, DimLvlType::kCompressed)) {
+    if (isCompressedDLT(merger.getDimLevelType(tensor, idx))) {
       // Compressed dimension, fetch pointer and indices.
       auto ptrTp =
           MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
@@ -563,7 +532,7 @@
           builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
       codegen.indices[tensor][idx] =
           builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
-    } else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) {
+    } else if (isSingletonDLT(merger.getDimLevelType(tensor, idx))) {
       // Singleton dimension, fetch indices.
       auto indTp =
           MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
@@ -572,7 +541,7 @@
           builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
     } else {
       // Dense dimension, nothing to fetch.
-      assert(merger.isDimLevelType(tensor, idx, DimLvlType::kDense));
+      assert(isDenseDLT(merger.getDimLevelType(tensor, idx)));
     }
     // Find upper bound in current dimension.
     unsigned p = toOrigDim(enc, d);
@@ -1195,7 +1164,7 @@
       continue;
     unsigned tensor = merger.tensor(b);
     assert(idx == merger.index(b));
-    if (merger.isDimLevelType(b, DimLvlType::kCompressed)) {
+    if (isCompressedDLT(merger.getDimLevelType(b))) {
       // Initialize sparse index that will implement the iteration:
       //   for pidx_idx = pointers(pidx_idx-1), pointers(1+pidx_idx-1)
       unsigned pat = at;
@@ -1210,7 +1179,7 @@
       codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0);
       Value p1 = builder.create<arith::AddIOp>(loc, p0, one);
       codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1);
-    } else if (merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+    } else if (isSingletonDLT(merger.getDimLevelType(b))) {
       // Initialize sparse index that will implement the "iteration":
       //   for pidx_idx = pidx_idx-1, 1+pidx_idx-1
       // We rely on subsequent loop unrolling to get rid of the loop
@@ -1226,8 +1195,8 @@
       codegen.pidxs[tensor][idx] = p0;
       codegen.highs[tensor][idx] = builder.create<arith::AddIOp>(loc, p0, one);
     } else {
-      assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
-             merger.isDimLevelType(b, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(b)) ||
+             isUndefDLT(merger.getDimLevelType(b)));
       // Dense index still in play.
       needsUniv = true;
     }
@@ -1316,8 +1285,8 @@
   assert(idx == merger.index(fb));
   auto iteratorTypes = op.getIteratorTypesArray();
   bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]);
-  bool isSparse = merger.isDimLevelType(fb, DimLvlType::kCompressed) ||
-                  merger.isDimLevelType(fb, DimLvlType::kSingleton);
+  bool isSparse = isCompressedDLT(merger.getDimLevelType(fb)) ||
+                  isSingletonDLT(merger.getDimLevelType(fb));
   bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
                   denseUnitStrides(merger, op, idx);
   bool isParallel =
@@ -1392,15 +1361,15 @@
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
     if (!indices[b])
       continue;
-    if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
-        merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+    if (isCompressedDLT(merger.getDimLevelType(b)) ||
+        isSingletonDLT(merger.getDimLevelType(b))) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       types.push_back(indexType);
       operands.push_back(codegen.pidxs[tensor][idx]);
     } else {
-      assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
-             merger.isDimLevelType(b, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(b)) ||
+             isUndefDLT(merger.getDimLevelType(b)));
     }
   }
   if (codegen.redVal) {
@@ -1431,8 +1400,8 @@
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
     if (!indices[b])
       continue;
-    if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
-        merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+    if (isCompressedDLT(merger.getDimLevelType(b)) ||
+        isSingletonDLT(merger.getDimLevelType(b))) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value op1 = before->getArgument(o);
@@ -1442,8 +1411,8 @@
       cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc;
       codegen.pidxs[tensor][idx] = after->getArgument(o++);
     } else {
-      assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
-             merger.isDimLevelType(b, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(b)) ||
+             isUndefDLT(merger.getDimLevelType(b)));
     }
   }
   if (codegen.redVal)
@@ -1486,8 +1455,8 @@
   for (unsigned b = 0, be = locals.size(); b < be; b++) {
     if (!locals[b])
       continue;
-    if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
-        merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+    if (isCompressedDLT(merger.getDimLevelType(b)) ||
+        isSingletonDLT(merger.getDimLevelType(b))) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value ptr = codegen.indices[tensor][idx];
@@ -1504,8 +1473,8 @@
         }
       }
     } else {
-      assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
-             merger.isDimLevelType(b, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(b)) ||
+             isUndefDLT(merger.getDimLevelType(b)));
     }
   }
 
@@ -1520,7 +1489,7 @@
   // but may be needed for linearized codegen.
   for (unsigned b = 0, be = locals.size(); b < be; b++) {
     if ((locals[b] || merger.isOutTensor(b, idx)) &&
-        merger.isDimLevelType(b, DimLvlType::kDense)) {
+        isDenseDLT(merger.getDimLevelType(b))) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       unsigned pat = at;
@@ -1572,8 +1541,8 @@
   for (unsigned b = 0, be = induction.size(); b < be; b++) {
     if (!induction[b])
       continue;
-    if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
-        merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+    if (isCompressedDLT(merger.getDimLevelType(b)) ||
+        isSingletonDLT(merger.getDimLevelType(b))) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value op1 = codegen.idxs[tensor][idx];
@@ -1585,8 +1554,8 @@
       operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, op3));
       codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
     } else {
-      assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
-             merger.isDimLevelType(b, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(b)) ||
+             isUndefDLT(merger.getDimLevelType(b)));
     }
   }
   if (codegen.redVal) {
@@ -1641,15 +1610,15 @@
     unsigned tensor = merger.tensor(b);
     assert(idx == merger.index(b));
     Value clause;
-    if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
-        merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+    if (isCompressedDLT(merger.getDimLevelType(b)) ||
+        isSingletonDLT(merger.getDimLevelType(b))) {
       Value op1 = codegen.idxs[tensor][idx];
       Value op2 = codegen.loops[idx];
       clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                              op1, op2);
     } else {
-      assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
-             merger.isDimLevelType(b, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(b)) ||
+             isUndefDLT(merger.getDimLevelType(b)));
       clause = constantI1(builder, loc, true);
     }
     cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
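NOTE (reviewer sketch, not part of the patch): every Sparsification.cpp change
above is the same mechanical rewrite: fetch the DimLevelType once and apply the
shared predicates, which also match the Nu/No property variants that the old
kCompressed/kSingleton tags could not distinguish. Condensed into a
hypothetical helper:

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

// Before: merger.isDimLevelType(b, DimLvlType::kCompressed) || ...kSingleton.
// After: one lookup plus property-insensitive predicates.
static bool isSparseBit(const mlir::sparse_tensor::Merger &merger,
                        unsigned b) {
  using namespace mlir::sparse_tensor;
  const DimLevelType dlt = merger.getDimLevelType(b);
  return isCompressedDLT(dlt) || isSingletonDLT(dlt);
}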
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
@@ -9,4 +9,5 @@
   MLIRComplexDialect
   MLIRIR
   MLIRLinalgDialect
+  MLIRSparseTensorEnums
 )
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -271,7 +271,7 @@
   // Starts resetting from a dense dimension, so that the first bit (if kept)
   // is not undefined dimension type.
   for (unsigned b = 0; b < be; b++) {
-    if (simple[b] && isDimLevelType(b, DimLvlType::kDense)) {
+    if (simple[b] && isDenseDLT(getDimLevelType(b))) {
      offset = be - b - 1; // relative to the end
      break;
    }
@@ -281,8 +281,8 @@
   // keep the rightmost bit (which could possibly be a synthetic tensor).
   for (unsigned b = be - 1 - offset, i = 0; i < be;
        b = b == 0 ? be - 1 : b - 1, i++) {
-    if (simple[b] && (!isDimLevelType(b, DimLvlType::kCompressed) &&
-                      !isDimLevelType(b, DimLvlType::kSingleton))) {
+    if (simple[b] && (!isCompressedDLT(getDimLevelType(b)) &&
+                      !isSingletonDLT(getDimLevelType(b)))) {
       if (reset)
         simple.reset(b);
       reset = true;
@@ -396,8 +396,8 @@
 
 bool Merger::hasAnySparse(const BitVector &bits) const {
   for (unsigned b = 0, be = bits.size(); b < be; b++)
-    if (bits[b] && (isDimLevelType(b, DimLvlType::kCompressed) ||
-                    isDimLevelType(b, DimLvlType::kSingleton)))
+    if (bits[b] && (isCompressedDLT(getDimLevelType(b)) ||
+                    isSingletonDLT(getDimLevelType(b))))
       return true;
   return false;
 }
@@ -613,23 +613,18 @@
     if (bits[b]) {
       unsigned t = tensor(b);
       unsigned i = index(b);
-      DimLevelFormat f = dims[t][i];
+      DimLevelType dlt = dimTypes[t][i];
       llvm::dbgs() << " i_" << t << "_" << i << "_";
-      switch (f.levelType) {
-      case DimLvlType::kDense:
+      if (isDenseDLT(dlt))
         llvm::dbgs() << "D";
-        break;
-      case DimLvlType::kCompressed:
+      else if (isCompressedDLT(dlt))
         llvm::dbgs() << "C";
-        break;
-      case DimLvlType::kSingleton:
+      else if (isSingletonDLT(dlt))
         llvm::dbgs() << "S";
-        break;
-      case DimLvlType::kUndef:
+      else if (isUndefDLT(dlt))
         llvm::dbgs() << "U";
-        break;
-      }
-      llvm::dbgs() << "[O=" << f.isOrdered << ",U=" << f.isUnique << "]";
+      llvm::dbgs() << "[O=" << isOrderedDLT(dlt) << ",U=" << isUniqueDLT(dlt)
+                   << "]";
     }
   }
 }
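NOTE (reviewer sketch, not part of the patch): the dump change derives the
printed D/C/S/U letter and the O=/U= flags from the DLT predicates rather than
the removed DimLevelFormat booleans. The letter mapping, pulled out as a
standalone helper for clarity:

#include "mlir/Dialect/SparseTensor/IR/Enums.h"

static char dltLetter(mlir::sparse_tensor::DimLevelType dlt) {
  using namespace mlir::sparse_tensor;
  if (isDenseDLT(dlt))
    return 'D';
  if (isCompressedDLT(dlt))
    return 'C';
  if (isSingletonDLT(dlt))
    return 'S';
  return 'U'; // Undef (the only remaining valid case)
}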
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -311,15 +311,15 @@
   MergerTest3T1L() : MergerTestBase(3, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelType(t2, l0, DimLevelType::Dense);
   }
 };
 
@@ -334,19 +334,19 @@
   MergerTest4T1L() : MergerTestBase(4, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
 
     // Tensor 2: sparse input vector
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
 
     // Tensor 3: dense output vector
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDimLevelFormat(t3, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelType(t3, l0, DimLevelType::Dense);
   }
 };
 
@@ -365,15 +365,15 @@
   MergerTest3T1LD() : MergerTestBase(3, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
 
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelType(t1, l0, DimLevelType::Dense);
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelType(t2, l0, DimLevelType::Dense);
   }
 };
 
@@ -392,19 +392,19 @@
   MergerTest4T1LU() : MergerTestBase(4, 1) {
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef));
+    merger.setDimLevelType(t0, l0, DimLevelType::Undef);
 
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelType(t1, l0, DimLevelType::Dense);
 
     // Tensor 2: undef input vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kUndef));
+    merger.setDimLevelType(t2, l0, DimLevelType::Undef);
 
     // Tensor 3: dense output vector.
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDimLevelFormat(t3, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelType(t3, l0, DimLevelType::Dense);
   }
 };
 
@@ -425,15 +425,15 @@
 
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef));
+    merger.setDimLevelType(t0, l0, DimLevelType::Undef);
 
     // Tensor 1: undef input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kUndef));
+    merger.setDimLevelType(t1, l0, DimLevelType::Undef);
 
     // Tensor 2: sparse output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
   }
 };
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2104,6 +2104,7 @@
         ":LinalgDialect",
         ":MathDialect",
         ":SparseTensorDialect",
+        ":SparseTensorEnums",
         "//llvm:Support",
     ],
 )
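NOTE (reviewer sketch, not part of the patch): the property booleans of the
removed DimLevelFormat map onto the enum suffixes, e.g. the old
DimLevelFormat(DimLvlType::kCompressed, /*o=*/true, /*u=*/false) corresponds to
DimLevelType::CompressedNu. Assuming isOrderedDLT/isUniqueDLT (used in the
Merger.cpp hunk) are constexpr like the predicates added above, this mapping
can even be checked at compile time:

#include "mlir/Dialect/SparseTensor/IR/Enums.h"

using namespace mlir::sparse_tensor;

static_assert(isCompressedDLT(DimLevelType::CompressedNu),
              "Nu keeps the compressed format");
static_assert(!isUniqueDLT(DimLevelType::CompressedNu),
              "Nu means 'not unique' (old u=false)");
static_assert(isOrderedDLT(DimLevelType::CompressedNu),
              "ordering is unaffected (old o=true)");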