diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -21,7 +21,20 @@
 namespace sparse_tensor {
 
 /// Dimension level type for a tensor (undef means index does not appear).
-enum Dim { kSparse, kDense, kUndef };
+enum class DimLvlType { kDense, kCompressed, kSingleton, kUndef };
+
+/// Per-dimension level format (type and properties). Dense and undefined
+/// level types should always be marked ordered and unique.
+struct DimLevelFormat {
+  DimLevelFormat(DimLvlType tp, bool o = true, bool u = true)
+      : levelType(tp), isOrdered(o), isUnique(u) {
+    assert((tp == DimLvlType::kCompressed || tp == DimLvlType::kSingleton) ||
+           (o && u));
+  }
+  DimLvlType levelType;
+  bool isOrdered;
+  bool isUnique;
+};
 
 /// Tensor expression kind.
 enum Kind {
@@ -156,7 +169,9 @@
   /// invariant expressions in the kernel.
   Merger(unsigned t, unsigned l)
       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
-        hasSparseOut(false), dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
+        hasSparseOut(false),
+        dims(t + 1, std::vector<DimLevelFormat>(
+                        l, DimLevelFormat(DimLvlType::kUndef))) {}
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
@@ -225,31 +240,40 @@
   unsigned tensor(unsigned b) const { return b % numTensors; }
   unsigned index(unsigned b) const { return b / numTensors; }
 
-  /// Returns true if bit corresponds to queried dim.
-  bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
-
   /// Returns true if bit corresponds to index of output tensor.
   bool isOutTensor(unsigned b, unsigned i) const {
     return tensor(b) == outTensor && index(b) == i;
   }
 
-  /// Returns true if tensor access at given index has queried dim.
-  bool isDim(unsigned t, unsigned i, Dim d) const {
-    assert(t < numTensors && i < numLoops);
-    return dims[t][i] == d;
-  }
-
-  /// Returns true if any set bit corresponds to queried dim.
-  bool hasAnyDimOf(const BitVector &bits, Dim d) const;
-
   /// Returns true if given tensor iterates *only* in the given tensor
   /// expression. For the output tensor, this defines a "simply dynamic"
   /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
   /// sparse vector a.
   bool isSingleCondition(unsigned t, unsigned e) const;
 
-  /// Dimension setter.
-  void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
+  /// Returns true if bit corresponds to given dimension level type.
+  bool isDimLevelType(unsigned b, DimLvlType tp) const {
+    return isDimLevelType(tensor(b), index(b), tp);
+  }
+
+  /// Returns true if tensor access at index has given dimension level type.
+  bool isDimLevelType(unsigned t, unsigned i, DimLvlType tp) const {
+    return getDimLevelFormat(t, i).levelType == tp;
+  }
+
+  /// Returns true if any set bit corresponds to given dimension level type.
+  bool hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const;
+
+  /// Dimension level format getter.
+  DimLevelFormat getDimLevelFormat(unsigned t, unsigned i) const {
+    assert(t < numTensors && i < numLoops);
+    return dims[t][i];
+  }
+
+  /// Dimension level format setter.
+  void setDimLevelFormat(unsigned t, unsigned i, DimLevelFormat d) {
+    dims[t][i] = d;
+  }
 
   // Has sparse output tensor setter.
   void setHasSparseOut(bool s) { hasSparseOut = s; }
@@ -298,7 +322,7 @@
   const unsigned numTensors;
   const unsigned numLoops;
   bool hasSparseOut;
-  std::vector<std::vector<Dim>> dims;
+  std::vector<std::vector<DimLevelFormat>> dims;
   llvm::SmallVector<TensorExp, 32> tensorExps;
   llvm::SmallVector<LatPoint, 16> latPoints;
   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
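
To make the new invariant concrete, here is a minimal standalone sketch (the enum and struct are duplicated from the header above; the `main` body is illustrative and not part of the patch) showing how the ordered/unique property bits interact with the constructor assertion:

```c++
#include <cassert>

enum class DimLvlType { kDense, kCompressed, kSingleton, kUndef };

struct DimLevelFormat {
  DimLevelFormat(DimLvlType tp, bool o = true, bool u = true)
      : levelType(tp), isOrdered(o), isUnique(u) {
    // Only compressed and singleton levels may relax the properties.
    assert((tp == DimLvlType::kCompressed || tp == DimLvlType::kSingleton) ||
           (o && u));
  }
  DimLvlType levelType;
  bool isOrdered;
  bool isUnique;
};

int main() {
  DimLevelFormat dense(DimLvlType::kDense); // ordered, unique by default
  DimLevelFormat compressedNu(DimLvlType::kCompressed,
                              /*o=*/true, /*u=*/false); // "compressed-nu"
  assert(dense.isOrdered && dense.isUnique);
  assert(compressedNu.isOrdered && !compressedNu.isUnique);
  // DimLevelFormat bad(DimLvlType::kDense, /*o=*/false); // would assert
  return 0;
}
```
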
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -141,40 +141,60 @@
   return d;
 }
 
-/// Helper method to translate dim level type to internal representation.
-static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
+/// Helper method to obtain the dimension level format from the encoding.
+//
+// TODO: note that we store, but currently completely *ignore* the properties
+//
+static DimLevelFormat toDimLevelFormat(const SparseTensorEncodingAttr &enc,
+                                       unsigned d) {
   if (enc) {
-    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
-    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
-      return Dim::kSparse;
+    switch (enc.getDimLevelType()[d]) {
+    case SparseTensorEncodingAttr::DimLevelType::Dense:
+      return DimLevelFormat(DimLvlType::kDense);
+    case SparseTensorEncodingAttr::DimLevelType::Compressed:
+      return DimLevelFormat(DimLvlType::kCompressed);
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+      return DimLevelFormat(DimLvlType::kCompressed, true, false);
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+      return DimLevelFormat(DimLvlType::kCompressed, false, true);
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+      return DimLevelFormat(DimLvlType::kCompressed, false, false);
+    case SparseTensorEncodingAttr::DimLevelType::Singleton:
+      return DimLevelFormat(DimLvlType::kSingleton);
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+      return DimLevelFormat(DimLvlType::kSingleton, true, false);
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+      return DimLevelFormat(DimLvlType::kSingleton, false, true);
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+      return DimLevelFormat(DimLvlType::kSingleton, false, false);
+    }
   }
-  return Dim::kDense;
+  return DimLevelFormat(DimLvlType::kDense);
 }
 
 /// Helper method to inspect affine expressions. Rejects cases where the
-/// same index is used more than once. Also rejects affine expressions
-/// that are not a direct index for annotated tensors.
-// TODO: accept more affine cases for sparse tensors
-static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
-                       bool isDense) {
+/// same index is used more than once. Also rejects compound affine
+/// expressions in sparse dimensions.
+static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
+                       DimLevelFormat dim) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
-    if (!merger.isDim(tensor, idx, Dim::kUndef))
+    if (!merger.isDimLevelType(tensor, idx, DimLvlType::kUndef))
       return false; // used more than once
-    merger.setDim(tensor, idx, dim);
+    merger.setDimLevelFormat(tensor, idx, dim);
     return true;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
-    if (!isDense)
-      return false;
+    if (dim.levelType != DimLvlType::kDense)
+      return false; // compound only in dense dim
    auto binOp = a.cast<AffineBinaryOpExpr>();
-    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
-           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
+    return findAffine(merger, tensor, binOp.getLHS(), dim) &&
+           findAffine(merger, tensor, binOp.getRHS(), dim);
   }
   case AffineExprKind::Constant:
-    return isDense;
+    return dim.levelType == DimLvlType::kDense; // const only in dense dim
   default:
     return false;
  }
 }
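
The switch in `toDimLevelFormat` is a pure lookup table. The following hedged sketch restates it with a local mock of the attribute enum (the real one comes from the sparse tensor dialect's tablegen definitions), which may help when auditing the `Nu`/`No` decoding; `Nu` means "not unique" and `No` means "not ordered":

```c++
#include <cassert>
#include <tuple>

// Local stand-ins; not the real MLIR types.
enum class AttrDLT { Dense, Compressed, CompressedNu, CompressedNo,
                     CompressedNuNo, Singleton, SingletonNu, SingletonNo,
                     SingletonNuNo };
enum class DimLvlType { kDense, kCompressed, kSingleton, kUndef };

// Returns (levelType, isOrdered, isUnique), matching the switch above.
static std::tuple<DimLvlType, bool, bool> decode(AttrDLT tp) {
  switch (tp) {
  case AttrDLT::Dense:          return {DimLvlType::kDense, true, true};
  case AttrDLT::Compressed:     return {DimLvlType::kCompressed, true, true};
  case AttrDLT::CompressedNu:   return {DimLvlType::kCompressed, true, false};
  case AttrDLT::CompressedNo:   return {DimLvlType::kCompressed, false, true};
  case AttrDLT::CompressedNuNo: return {DimLvlType::kCompressed, false, false};
  case AttrDLT::Singleton:      return {DimLvlType::kSingleton, true, true};
  case AttrDLT::SingletonNu:    return {DimLvlType::kSingleton, true, false};
  case AttrDLT::SingletonNo:    return {DimLvlType::kSingleton, false, true};
  case AttrDLT::SingletonNuNo:  return {DimLvlType::kSingleton, false, false};
  }
  return {DimLvlType::kUndef, true, true}; // unreachable
}

int main() {
  auto [t, o, u] = decode(AttrDLT::CompressedNuNo);
  assert(t == DimLvlType::kCompressed && !o && !u);
  return 0;
}
```

As the TODO in the patch notes, the properties are stored but not yet acted upon by the sparsifier.
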
@@ -196,7 +216,7 @@
   for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
     unsigned tensor = t->getOperandNumber();
     AffineExpr a = map.getResult(perm(enc, d));
-    if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
+    if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
       return false; // inadmissable affine expression
   }
 }
@@ -286,13 +306,13 @@
     if (mask & SortMask::kIncludeUndef) {
       unsigned tensor = t->getOperandNumber();
       for (unsigned i = 0; i < n; i++)
-        if (merger.isDim(tensor, i, Dim::kSparse))
+        if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
+            merger.isDimLevelType(tensor, i, DimLvlType::kSingleton))
           for (unsigned j = 0; j < n; j++)
-            if (merger.isDim(tensor, j, Dim::kUndef))
+            if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef))
               adjM[i][j] = true;
     }
   }
-
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
   topSort.clear();
@@ -334,7 +354,8 @@
   auto iteratorTypes = op.iterator_types().getValue();
   unsigned numLoops = iteratorTypes.size();
   for (unsigned i = 0; i < numLoops; i++)
-    if (merger.isDim(tensor, i, Dim::kSparse)) {
+    if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
+        merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
      allDense = false;
       break;
     }
@@ -519,7 +540,7 @@
       continue; // compound
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
     // Handle sparse storage schemes.
-    if (merger.isDim(tensor, idx, Dim::kSparse)) {
+    if (merger.isDimLevelType(tensor, idx, DimLvlType::kCompressed)) {
       auto dynShape = {ShapedType::kDynamicSize};
       auto ptrTp =
           MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
@@ -531,6 +552,8 @@
           builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
       codegen.indices[tensor][idx] =
           builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
+    } else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) {
+      llvm_unreachable("TODO: not implemented yet");
     }
     // Find upper bound in current dimension.
     unsigned p = perm(enc, d);
@@ -543,7 +566,6 @@
   // Perform the required bufferization. Dense inputs materialize
   // from the input tensors. Dense outputs need special handling.
   // Sparse inputs use sparse primitives to obtain the values.
-  // We also accept in-place all-dense annotated "sparse" outputs.
   Type elementType = getElementTypeOrSelf(t->get().getType());
   if (!enc) {
     // Non-annotated dense tensors.
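
For context on why only `kCompressed` levels materialize pointers/indices buffers here while `kSingleton` hits `llvm_unreachable`: a compressed level is iterated through a positions/coordinates pair, as in this plain-C++ CSR-style sketch (array names are illustrative, not the MLIR runtime's):

```c++
#include <cstdio>

int main() {
  // 4x5 matrix with compressed rows: entries (0,1)=1, (0,4)=2, (2,3)=3.
  int ptr[] = {0, 2, 2, 3, 3};    // pointers: row i spans [ptr[i], ptr[i+1])
  int ind[] = {1, 4, 3};          // indices: column of each stored entry
  double val[] = {1.0, 2.0, 3.0}; // values, parallel to `ind`
  for (int i = 0; i < 4; i++)
    for (int p = ptr[i]; p < ptr[i + 1]; p++)
      std::printf("A[%d][%d] = %g\n", i, ind[p], val[p]);
  return 0;
}
```

A singleton level stores only one coordinate per parent entry (no pointers array), so it needs a different access pattern; the patch deliberately leaves that codegen as an explicit TODO.
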
@@ -985,11 +1007,13 @@
     return genInvariantValue(merger, codegen, rewriter, exp);
   if (merger.exp(exp).kind == Kind::kIndex)
     return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx);
+
   if (merger.exp(exp).kind == Kind::kReduce) {
     // Make custom reduction identity accessible for expanded access pattern.
     assert(codegen.redCustom == -1u);
     codegen.redCustom = exp;
   }
+
   Value v0 =
       genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
   Value v1 =
@@ -1000,8 +1024,12 @@
        merger.exp(exp).kind == Kind::kBinaryBranch ||
        merger.exp(exp).kind == Kind::kReduce))
     ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
-  if (merger.exp(exp).kind == Kind::kReduce)
+
+  if (merger.exp(exp).kind == Kind::kReduce) {
+    assert(codegen.redCustom != -1u);
     codegen.redCustom = -1u;
+  }
+
   return ee;
 }
@@ -1029,7 +1057,7 @@
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp, unsigned ldx,
-                          bool atStart, unsigned last = 0) {
+                          bool atStart, unsigned last = -1u) {
   if (exp == -1u)
     return;
   if (merger.exp(exp).kind == Kind::kTensor) {
@@ -1131,7 +1159,7 @@
     if (inits[b]) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
-      if (merger.isDim(b, Dim::kSparse)) {
+      if (merger.isDimLevelType(b, DimLvlType::kCompressed)) {
         // Initialize sparse index.
         unsigned pat = at;
         for (; pat != 0; pat--) {
@@ -1145,6 +1173,8 @@
         codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0);
         Value p1 = builder.create<arith::AddIOp>(loc, p0, one);
         codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1);
+      } else if (merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+        llvm_unreachable("TODO: not implemented yet");
       } else {
         // Dense index still in play.
         needsUniv = true;
@@ -1235,12 +1265,14 @@
   assert(idx == merger.index(fb));
   auto iteratorTypes = op.iterator_types().getValue();
   bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]);
-  bool isSparse = merger.isDim(fb, Dim::kSparse);
+  bool isSparse = merger.isDimLevelType(fb, DimLvlType::kCompressed);
   bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
                   denseUnitStrides(merger, op, idx);
   bool isParallel =
       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
 
+  assert(!merger.isDimLevelType(fb, DimLvlType::kSingleton) &&
+         "TODO: implement");
+
   // Prepare vector length.
   if (isVector)
     codegen.curVecLength = codegen.options.vectorLength;
@@ -1308,7 +1340,7 @@
   // Construct the while-loop with a parameter for each index.
   Type indexType = builder.getIndexType();
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
-    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
+    if (indices[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       types.push_back(indexType);
@@ -1341,7 +1373,8 @@
   Value cond;
   unsigned o = 0;
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
-    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
+    // TODO: singleton
+    if (indices[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value op1 = before->getArgument(o);
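
The while-loop construction above still handles only compressed levels (hence the `// TODO: singleton` markers). What it emits amounts to the following scalar coiteration over two sparse vectors, sketched in plain C++ rather than generated IR; the names and data are illustrative only:

```c++
#include <algorithm>
#include <cstdio>

int main() {
  int aInd[] = {0, 3, 7}, bInd[] = {3, 4, 7};
  double aVal[] = {1, 2, 3}, bVal[] = {4, 5, 6};
  unsigned p0 = 0, high0 = 3, p1 = 0, high1 = 3;
  while (p0 < high0 && p1 < high1) { // the generated "before" region condition
    int i0 = aInd[p0], i1 = bInd[p1];
    int idx = std::min(i0, i1);      // loop index for this iteration
    if (i0 == idx && i1 == idx)      // both present: the a(i) * b(i) case
      std::printf("i=%d a*b=%g\n", idx, aVal[p0] * bVal[p1]);
    p0 += (i0 == idx);               // induction: conditional increments
    p1 += (i1 == idx);
  }
  return 0;
}
```
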
@@ -1389,7 +1422,8 @@
   // Initialize sparse indices.
   Value min;
   for (unsigned b = 0, be = locals.size(); b < be; b++) {
-    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
+    // TODO: singleton
+    if (locals[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value ptr = codegen.indices[tensor][idx];
@@ -1419,7 +1453,7 @@
   // but may be needed for linearized codegen.
   for (unsigned b = 0, be = locals.size(); b < be; b++) {
     if ((locals[b] || merger.isOutTensor(b, idx)) &&
-        merger.isDim(b, Dim::kDense)) {
+        merger.isDimLevelType(b, DimLvlType::kDense)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       unsigned pat = at;
@@ -1477,7 +1511,8 @@
   SmallVector<Value> operands;
   Value one = constantIndex(builder, loc, 1);
   for (unsigned b = 0, be = induction.size(); b < be; b++) {
-    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
+    // TODO: singleton
+    if (induction[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value op1 = codegen.idxs[tensor][idx];
@@ -1541,7 +1576,8 @@
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value clause;
-      if (merger.isDim(b, Dim::kSparse)) {
+      // TODO: singleton
+      if (merger.isDimLevelType(b, DimLvlType::kCompressed)) {
        Value op1 = codegen.idxs[tensor][idx];
         Value op2 = codegen.loops[idx];
         clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
@@ -1605,7 +1641,8 @@
   unsigned lsize = merger.set(lts).size();
   for (unsigned i = 1; i < lsize; i++) {
     unsigned li = merger.set(lts)[i];
-    if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
+    if (!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kCompressed) &&
+        !merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kSingleton))
       return true;
   }
 }
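
Note that every former `Dim::kSparse` query now expands into the same two-level-type disjunction. A hypothetical helper, not part of this patch, could factor that out for call sites like the sort-mask and lattice checks above:

```c++
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

using namespace mlir::sparse_tensor;

// Hypothetical convenience wrapper over the new Merger API: "sparse" in the
// old sense now means either a compressed or a singleton level type.
static bool isSparseDimLevelType(const Merger &merger, unsigned b) {
  return merger.isDimLevelType(b, DimLvlType::kCompressed) ||
         merger.isDimLevelType(b, DimLvlType::kSingleton);
}
```

Keeping the disjunction inline, as the patch does, arguably leaves each `TODO: singleton` site visible for the upcoming singleton codegen work.
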
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -262,10 +262,17 @@
     }
   }
   // Now apply the two basic rules.
+  //
+  // TODO: improve for singleton and properties
+  //
   BitVector simple = latPoints[p0].bits;
-  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
+  bool reset = isSingleton &&
+               (hasAnyDimLevelTypeOf(simple, DimLvlType::kCompressed) ||
+                hasAnyDimLevelTypeOf(simple, DimLvlType::kSingleton));
   for (unsigned b = 0, be = simple.size(); b < be; b++) {
-    if (simple[b] && !isDim(b, kSparse)) {
+    if (simple[b] &&
+        (!isDimLevelType(b, DimLvlType::kCompressed) &&
+         !isDimLevelType(b, DimLvlType::kSingleton))) {
       if (reset)
         simple.reset(b);
       reset = true;
@@ -290,14 +297,8 @@
 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
   BitVector tmp = latPoints[j].bits;
   tmp ^= latPoints[i].bits;
-  return !hasAnyDimOf(tmp, kSparse);
-}
-
-bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
-  for (unsigned b = 0, be = bits.size(); b < be; b++)
-    if (bits[b] && isDim(b, d))
-      return true;
-  return false;
+  return !hasAnyDimLevelTypeOf(tmp, DimLvlType::kCompressed) &&
+         !hasAnyDimLevelTypeOf(tmp, DimLvlType::kSingleton);
 }
 
 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
@@ -383,6 +384,13 @@
   llvm_unreachable("unexpected kind");
 }
 
+bool Merger::hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const {
+  for (unsigned b = 0, be = bits.size(); b < be; b++)
+    if (bits[b] && isDimLevelType(b, tp))
+      return true;
+  return false;
+}
+
 #ifndef NDEBUG
 
 //===----------------------------------------------------------------------===//
@@ -591,18 +599,23 @@
     if (bits[b]) {
       unsigned t = tensor(b);
       unsigned i = index(b);
+      DimLevelFormat f = dims[t][i];
       llvm::dbgs() << " i_" << t << "_" << i << "_";
-      switch (dims[t][i]) {
-      case kSparse:
-        llvm::dbgs() << "S";
-        break;
-      case kDense:
+      switch (f.levelType) {
+      case DimLvlType::kDense:
         llvm::dbgs() << "D";
         break;
-      case kUndef:
+      case DimLvlType::kCompressed:
+        llvm::dbgs() << "C";
+        break;
+      case DimLvlType::kSingleton:
+        llvm::dbgs() << "S";
+        break;
+      case DimLvlType::kUndef:
         llvm::dbgs() << "U";
         break;
       }
+      llvm::dbgs() << "[O=" << f.isOrdered << ",U=" << f.isUnique << "]";
     }
   }
 }
@@ -855,9 +868,8 @@
   if (isa<arith::ConstantOp>(def))
     return true;
   // Operation defined outside branch.
-  if (def->getBlock() != block) {
+  if (def->getBlock() != block)
     return def->getBlock() != op->getBlock(); // invariant?
-  }
   // Operation defined within branch. Anything is accepted,
   // as long as all subexpressions are admissable.
   for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
@@ -1038,7 +1050,6 @@
   if (x.has_value() && y.has_value() && z.has_value()) {
     unsigned e0 = x.value();
     unsigned e1 = y.value();
-    // unsigned e2 = z.getValue();
     if (auto redop = dyn_cast<ReduceOp>(def)) {
       if (isAdmissableBranch(redop, redop.getRegion()))
         return addExp(kReduce, e0, e1, Value(), def);
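
For readers of sparsification debug traces: the extended `dumpBits` output now appends the property bits after the level-type letter, so tensor 0 at loop index 1 with a compressed-nu level prints as `i_0_1_C[O=1,U=0]`. A tiny sketch of the same formatting with plain streams (booleans print as 1/0 by default, matching `llvm::dbgs()`):

```c++
#include <iostream>

int main() {
  unsigned t = 0, i = 1;
  bool isOrdered = true, isUnique = false; // compressed-nu
  std::cout << " i_" << t << "_" << i << "_" << "C"
            << "[O=" << isOrdered << ",U=" << isUnique << "]\n";
  return 0;
}
```
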
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -310,15 +310,15 @@
   MergerTest3T1L() : MergerTestBase(3, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDim(t0, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDim(t1, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kCompressed));
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDim(t2, l0, Dim::kDense);
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
   }
 };
 
@@ -333,19 +333,19 @@
   MergerTest4T1L() : MergerTestBase(4, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDim(t0, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDim(t1, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kCompressed));
     // Tensor 2: sparse input vector
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDim(t2, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kCompressed));
     // Tensor 3: dense output vector
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDim(t3, l0, Dim::kDense);
+    merger.setDimLevelFormat(t3, l0, DimLevelFormat(DimLvlType::kDense));
   }
 };
 
@@ -364,15 +364,15 @@
   MergerTest3T1LD() : MergerTestBase(3, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDim(t0, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDim(t1, l0, Dim::kDense);
+    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense));
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDim(t2, l0, Dim::kDense);
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
   }
 };
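
A natural follow-up test, once singleton codegen lands, would mirror the fixtures above. A hypothetical fixture (not in this patch, reusing the file's `MergerTestBase`, `t0`..`t2`, and `l0` conventions) that only swaps tensor 1's level format:

```c++
/// Three tensors (two inputs, one output); and a single loop; tensor 1 uses
/// a singleton level format (ordered and unique by default).
class MergerTest3T1LSingleton : public MergerTestBase {
protected:
  MergerTest3T1LSingleton() : MergerTestBase(3, 1) {
    // Tensor 0: compressed input vector.
    merger.addExp(Kind::kTensor, t0, -1u);
    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
    // Tensor 1: singleton input vector.
    merger.addExp(Kind::kTensor, t1, -1u);
    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kSingleton));
    // Tensor 2: dense output vector.
    merger.addExp(Kind::kTensor, t2, -1u);
    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
  }
};
```
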