diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -150,33 +150,28 @@ public: /// Constructs a merger for the given number of tensors, native loops, and /// filter loops. The user supplies the number of tensors involved in the - /// kernel, with the last tensor in this set denoting the output tensor. The - /// merger adds an additional synthetic tensor at the end of this set to - /// represent all invariant expressions in the kernel. - /// In addition to natives - /// loops (which are specified by the GenericOp), extra filter loops are - /// needed in order to handle affine expressions on sparse dimensions. - /// E.g., (d0, d1, d2) => (d0 + d1, d2), a naive implementation of the filter - /// loop could be generated as: + /// kernel, with the last tensor in this set denoting the output tensor. + /// The merger adds an additional synthetic tensor at the end of this set + /// to represent all invariant expressions in the kernel. + /// + /// In addition to native loops (which are specified by the GenericOp), + /// extra filter loops are needed in order to handle affine expressions on + /// sparse dimensions. E.g., (d0, d1, d2) => (d0 + d1, d2), a naive + /// implementation of the filter loop could be generated as + /// /// for (coord : sparse_dim[0]) /// if (coord == d0 + d1) { /// generated_code; /// } /// } - /// to filter out coordinates that are not equal to the affine expression - /// result. + /// + /// to filter out coordinates that are not equal to the affine expression. + /// /// TODO: we want to make the filter loop more efficient in the future, e.g., /// by avoiding scanning the full stored index sparse (keeping the last /// position in ordered list) or even apply binary search to find the index.
- Merger(unsigned t, unsigned l, unsigned fl) - : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), - numNativeLoops(l), numLoops(l + fl), hasSparseOut(false), - dimTypes(numTensors, - std::vector<DimLevelType>(numLoops, DimLevelType::Undef)), - loopIdxToDim(numTensors, - std::vector<std::optional<unsigned>>(numLoops, std::nullopt)), - dimToLoopIdx(numTensors, - std::vector<std::optional<unsigned>>(numLoops, std::nullopt)) {} + /// + Merger(unsigned t, unsigned l, unsigned fl); /// Adds a tensor expression. Returns its index. unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(), @@ -386,14 +381,18 @@ const unsigned numNativeLoops; const unsigned numLoops; bool hasSparseOut; + // Map that converts pair<tensor id, loop id> to the corresponding dimension // level type. std::vector<std::vector<DimLevelType>> dimTypes; + // Map that converts pair<tensor id, loop id> to the corresponding // dimension. std::vector<std::vector<std::optional<unsigned>>> loopIdxToDim; + // Map that converts pair<tensor id, dim> to the corresponding loop id. std::vector<std::vector<std::optional<unsigned>>> dimToLoopIdx; + llvm::SmallVector<TensorExp> tensorExps; llvm::SmallVector<LatPoint> latPoints; llvm::SmallVector<SmallVector<unsigned>> latSets; diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -205,6 +205,16 @@ LatPoint::LatPoint(const BitVector &b, unsigned e) : bits(b), exp(e) {} +Merger::Merger(unsigned t, unsigned l, unsigned fl) + : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), + numNativeLoops(l), numLoops(l + fl), hasSparseOut(false), + dimTypes(numTensors, + std::vector<DimLevelType>(numLoops, DimLevelType::Undef)), + loopIdxToDim(numTensors, + std::vector<std::optional<unsigned>>(numLoops, std::nullopt)), + dimToLoopIdx(numTensors, + std::vector<std::optional<unsigned>>(numLoops, std::nullopt)) {} + //===----------------------------------------------------------------------===// // Lattice methods.
//===----------------------------------------------------------------------===// @@ -740,17 +750,7 @@ unsigned t = tensor(b); unsigned i = index(b); DimLevelType dlt = dimTypes[t][i]; - llvm::dbgs() << " i_" << t << "_" << i << "_"; - if (isDenseDLT(dlt)) - llvm::dbgs() << "D"; - else if (isCompressedDLT(dlt)) - llvm::dbgs() << "C"; - else if (isSingletonDLT(dlt)) - llvm::dbgs() << "S"; - else if (isUndefDLT(dlt)) - llvm::dbgs() << "U"; - llvm::dbgs() << "[O=" << isOrderedDLT(dlt) << ",U=" << isUniqueDLT(dlt) - << "]"; + llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt); } } }