diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -156,7 +156,8 @@
   Merger(unsigned t, unsigned l)
       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
         hasSparseOut(false),
-        dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)) {}
+        dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)),
+        loopIdxToDim(t + 1, std::vector<Optional<unsigned>>(l, llvm::None)) {}

   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
@@ -246,7 +247,7 @@
   /// Returns true if any set bit corresponds to sparse dimension level type.
   bool hasAnySparse(const BitVector &bits) const;

-  /// Gets the dimension level type of the `i`th loop of the `t`th tensor.
+  /// Gets the dimension level type of the `t`th tensor on the `i`th loop.
   DimLevelType getDimLevelType(unsigned t, unsigned i) const {
     assert(t < numTensors && i < numLoops);
     return dimTypes[t][i];
@@ -257,10 +258,35 @@
     return getDimLevelType(tensor(b), index(b));
   }

-  /// Sets the dimension level type of the `i`th loop of the `t`th tensor.
-  void setDimLevelType(unsigned t, unsigned i, DimLevelType d) {
-    assert(isValidDLT(d));
-    dimTypes[t][i] = d;
+  /// Gets the dimension number of the `t`th tensor on the `i`th loop.
+  Optional<unsigned> getDimNum(unsigned t, unsigned i) const {
+    assert(t < numTensors && i < numLoops);
+    return loopIdxToDim[t][i];
+  }
+
+  /// Gets the dimension number of `b`.
+  Optional<unsigned> getDimNum(unsigned b) const {
+    return getDimNum(tensor(b), index(b));
+  }
+
+  /// Sets the dimension and dimension level type of the `t`th tensor on the
+  /// `i`th loop.
+  void setDimAndDimLevelType(unsigned t, unsigned i, unsigned dim,
+                             DimLevelType dlt) {
+    assert(isValidDLT(dlt));
+    dimTypes[t][i] = dlt;
+    loopIdxToDim[t][i] = dim;
+  }
+
+  // Iterates over the bits of a lattice; for each set bit, converts it into
+  // the corresponding tensor dimension and invokes the callback.
+  void foreachTidDimPairInBits(
+      const BitVector &bits,
+      function_ref<void(unsigned b, unsigned t, Optional<unsigned> dim,
+                        DimLevelType dlt)>
+          cb) {
+    for (unsigned b : bits.set_bits())
+      cb(b, tensor(b), getDimNum(b), getDimLevelType(b));
+  }

   // Has sparse output tensor setter.
@@ -310,7 +336,11 @@
   const unsigned numTensors;
   const unsigned numLoops;
   bool hasSparseOut;
+  // Map that converts pair <tensor id, loop id> to the corresponding
+  // dimension level type.
   std::vector<std::vector<DimLevelType>> dimTypes;
+  // Map that converts pair <tensor id, loop id> to the corresponding dimension.
+  std::vector<std::vector<Optional<unsigned>>> loopIdxToDim;
   llvm::SmallVector<TensorExp> tensorExps;
   llvm::SmallVector<LatPoint> latPoints;
   llvm::SmallVector<SmallVector<unsigned>, 8> latSets;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -40,8 +40,6 @@
 namespace {

-constexpr unsigned INVALID_ID = std::numeric_limits<unsigned>::max();
-
 // Iteration graph sorting.
 enum SortMask {
   kSparseOnly = 0x0,
@@ -83,14 +81,6 @@
   // Topsort (reference should remain in scope).
   std::vector<unsigned> &topSort;

-  // From tensor id + loop id => dim id.
-  // TODO: This map should probably be maintained by Merger (it can be set up
-  // together with dimLvlType Map).
-  std::vector<std::vector<unsigned>> loopIdxToDim;
-
-  // Initializes the above mapping.
-  void buildLoopIdxToDimMap(linalg::GenericOp op);
-
   Value getLoopIdxValue(size_t loopIdx) const {
     for (unsigned lv = 0; lv < topSort.size(); lv++)
       if (topSort[lv] == loopIdx)
         return loopEmitter.getLoopIV(lv);

     llvm_unreachable("invalid loop index");
   }
 };

@@ -100,30 +90,6 @@
-void CodeGen::buildLoopIdxToDimMap(linalg::GenericOp op) {
-  size_t numLoops = op.getNumLoops();
-  size_t numTensors = op.getNumOperands();
-  loopIdxToDim.assign(numTensors, std::vector<unsigned>(numLoops, INVALID_ID));
-
-  for (OpOperand &t : op->getOpOperands()) {
-    auto map = op.getMatchingIndexingMap(&t);
-    auto enc = getSparseTensorEncoding(t.get().getType());
-    // Scan all dimensions of current tensor.
-    unsigned tid = t.getOperandNumber();
-    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
-      auto a = map.getResult(toOrigDim(enc, d)).dyn_cast<AffineDimExpr>();
-      if (a) {
-        unsigned loopId = a.getPosition();
-        // Fills the mapping.
-        loopIdxToDim[tid][loopId] = d;
-      }
-      // Else a compound affine expression; do nothing. (At least we are fine
-      // for now, as we only support compound affine expressions on
-      // non-annotated dense tensors.)
-    }
-  }
-}
-
 } // namespace

//===----------------------------------------------------------------------===//
@@ -151,8 +117,9 @@
 /// Helper method to inspect affine expressions. Rejects cases where the
 /// same index is used more than once. Also rejects compound affine
 /// expressions in sparse dimensions.
-static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
-                       DimLevelType dim, bool setLvlFormat = true) {
+static bool findAffine(Merger &merger, unsigned tensor, unsigned dim,
+                       AffineExpr a, DimLevelType dlt,
+                       bool setLvlFormat = true) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
     if (!isUndefDLT(merger.getDimLevelType(tensor, idx)))
       return false; // used more than once

     if (setLvlFormat)
-      merger.setDimLevelType(tensor, idx, dim);
+      merger.setDimAndDimLevelType(tensor, idx, dim, dlt);
     return true;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
-    if (!isDenseDLT(dim))
+    if (!isDenseDLT(dlt))
       return false; // compound only in dense dim
     auto binOp = a.cast<AffineBinaryOpExpr>();
     // We do not set the dim level format for an affine expression like
     // d0 + d1; it is set on neither loop index d0 nor d1.
-    return findAffine(merger, tensor, binOp.getLHS(), dim, false) &&
-           findAffine(merger, tensor, binOp.getRHS(), dim, false);
+    return findAffine(merger, tensor, dim, binOp.getLHS(), dlt, false) &&
+           findAffine(merger, tensor, dim, binOp.getRHS(), dlt, false);
   }
   case AffineExprKind::Constant:
-    return isDenseDLT(dim); // const only in dense dim
+    return isDenseDLT(dlt); // const only in dense dim
   default:
     return false;
   }
@@ -196,7 +163,7 @@
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned tensor = t.getOperandNumber();
       AffineExpr a = map.getResult(toOrigDim(enc, d));
-      if (!findAffine(merger, tensor, a, getDimLevelType(enc, d)))
+      if (!findAffine(merger, tensor, d, a, getDimLevelType(enc, d)))
         return false; // inadmissible affine expression
     }
   }
@@ -1024,8 +991,7 @@
     Value clause;
     if (isCompressedDLT(merger.getDimLevelType(b)) ||
         isSingletonDLT(merger.getDimLevelType(b))) {
-      auto dim = codegen.loopIdxToDim[tensor][idx];
-      assert(dim != INVALID_ID);
+      auto dim = merger.getDimNum(tensor, idx).value();
       Value op1 = codegen.loopEmitter.getCoord()[tensor][dim];
       Value op2 = codegen.getLoopIdxValue(idx);
       clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                              op1, op2);
@@ -1082,23 +1048,22 @@
   unsigned l0 = merger.set(lts)[0];
   bool needsUniv = false;

-  SmallVector<size_t> ts;
-  SmallVector<size_t> ds;
-  for (auto b : merger.lat(l0).bits.set_bits()) {
-    if (isDenseDLT(merger.getDimLevelType(b)) ||
-        isUndefDLT(merger.getDimLevelType(b))) {
-      needsUniv = true;
-    } else {
-      unsigned tensor = merger.tensor(b);
-      assert(idx == merger.index(b));
-      size_t dim = codegen.loopIdxToDim[tensor][idx];
-      assert(dim != INVALID_ID);
-      ts.push_back(tensor);
-      ds.push_back(dim);
-    }
-  }
+  SmallVector<size_t> tids;
+  SmallVector<size_t> dims;
+  merger.foreachTidDimPairInBits(
+      merger.lat(l0).bits,
+      [&](unsigned b, unsigned tid, Optional<unsigned> dim, DimLevelType dlt) {
+        assert(merger.index(b) == idx);
+        if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+          needsUniv = true;
+        } else {
+          // Sparse/singleton dim levels.
+          tids.push_back(tid);
+          dims.push_back(dim.value());
+        }
+      });

-  codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), ts, ds);
+  codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), tids, dims);

   // Maintain the universal index only if it is actually
   // consumed by a subsequent lattice point.
@@ -1119,17 +1084,15 @@
                                        SmallVectorImpl<size_t> &condDims,
                                        SmallVectorImpl<size_t> &extraTids,
                                        SmallVectorImpl<size_t> &extraDims) {
-  const BitVector &simple = merger.lat(li).simple;
   const BitVector &all = merger.lat(li).bits;
-  assert(simple.size() == all.size());
-  // First converts bits to tensor id + dim pairs.
-  for (unsigned b = 0, e = simple.size(); b < e; b++) {
-    size_t tid = merger.tensor(b);
+  const BitVector &simple = merger.lat(li).simple;
+
+  // Converts bits to tensor id + dim pairs.
+  merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
+                                               Optional<unsigned> dim,
+                                               DimLevelType dlt) {
     if (simple.test(b)) {
-      // The simplified condition must be a subset of the original condition.
-      assert(all.test(b));
-      assert(merger.index(b) == idx);
-      if (isUndefDLT(merger.getDimLevelType(b))) {
+      if (isUndefDLT(dlt)) {
        // An undefined dlt in the lattice means we probably intend to iterate
        // based on the dim of the output tensor.
        // E.g., this could be a synthetic tensor (for invariants and sparse
        // output tensor).
        //   out[i][j] = invariant; or a broadcast
        //   out[i][j] = in[i] (j is undef for input)
         tid = merger.getOutTensorID();
+        dim = merger.getDimNum(tid, idx);
+        // Skips an invalid dim (e.g., when this is a zero-ranked tensor).
+        if (!dim)
+          return;
       }
-      auto dim = codegen.loopIdxToDim[tid][idx];
-      if (dim != INVALID_ID) {
-        // dim could be invalid if this is a zero-ranked tensor
-        condTids.push_back(tid);
-        condDims.push_back(dim);
-      }
-    } else if ((all.test(b) || merger.isOutTensor(b, idx)) &&
-               isDenseDLT(merger.getDimLevelType(b))) {
-      assert(merger.index(b) == idx);
-      // Note that we generate dense indices of the output tensor
-      // unconditionally, since they may not appear in the lattice, but may be
-      // needed for linearized codegen.
-      // Only dense dimensions should be optimized from conditions.
-      assert(isDenseDLT(merger.getDimLevelType(b)));
-      auto dim = codegen.loopIdxToDim[tid][idx];
-      assert(dim != INVALID_ID);
+      condTids.push_back(tid);
+      condDims.push_back(dim.value());
+    } else if (isDenseDLT(dlt)) {
+      // TODO: get rid of extraTids and extraDims.
       extraTids.push_back(tid);
-      extraDims.push_back(dim);
+      extraDims.push_back(dim.value());
     }
+  });
+
+  if (isDenseDLT(merger.getDimLevelType(merger.getOutTensorID(), idx))) {
+    // Note that we generate dense indices of the output tensor
+    // unconditionally, since they may not appear in the lattice, but may be
+    // needed for linearized codegen.
+    // Only dense dimensions should be optimized from conditions.
+    auto dim = merger.getDimNum(merger.getOutTensorID(), idx).value();
+    extraTids.push_back(merger.getOutTensorID());
+    extraDims.push_back(dim);
   }
 }
@@ -1370,8 +1335,6 @@
   // Recursively generates code if admissible.
   CodeGen codegen(options, tensors, numTensors, numLoops, sparseOut,
                   outerParNest, topSort);
-  // TODO: maybe the merger should be responsible for maintaining the map.
-  codegen.buildLoopIdxToDimMap(op);
   genBuffers(merger, codegen, rewriter, op);
   genStmt(merger, codegen, rewriter, op, exp, 0);
   genResult(merger, codegen, rewriter, op);
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -313,15 +313,15 @@

     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);

     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Compressed);

     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Dense);
   }
 };

@@ -338,19 +338,19 @@

     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);

     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Compressed);

     // Tensor 2: sparse input vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Compressed);

     // Tensor 3: dense output vector.
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDimLevelType(t3, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t3, l0, 0, DimLevelType::Dense);
   }
 };

@@ -371,15 +371,15 @@

     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);

     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Dense);

     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Dense);
   }
 };

@@ -400,19 +400,19 @@

     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Undef);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Undef);

     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Dense);

     // Tensor 2: undef input vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Undef);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Undef);

     // Tensor 3: dense output vector.
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDimLevelType(t3, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t3, l0, 0, DimLevelType::Dense);
   }
 };

@@ -436,15 +436,15 @@

     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Undef);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Undef);

     // Tensor 1: undef input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Undef);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Undef);

     // Tensor 2: sparse output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Compressed);
   }
 };
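
For readers unfamiliar with the new Merger API, here is a minimal standalone sketch (not part of the patch) of how the relocated mapping is populated and queried after this change, in the same style as the MergerTest.cpp fixtures above. The function name sketchMergerDimMapping is made up for illustration; addExp, addLat, and lat are pre-existing Merger helpers that the unit tests already rely on.

// Sketch only: exercises setDimAndDimLevelType / getDimNum /
// foreachTidDimPairInBits as introduced by this patch.
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include <cassert>

using namespace mlir;
using namespace mlir::sparse_tensor;

void sketchMergerDimMapping() {
  // Two inputs (t0, t1) and one output (t2), one loop (l0), matching the
  // MergerTest3T1L setup above.
  Merger merger(/*t=*/3, /*l=*/1);

  // The dimension number and the dimension level type are now registered
  // together through a single call.
  merger.setDimAndDimLevelType(/*t=*/0, /*i=*/0, /*dim=*/0,
                               DimLevelType::Compressed);
  merger.setDimAndDimLevelType(/*t=*/2, /*i=*/0, /*dim=*/0,
                               DimLevelType::Dense);

  // Query by (tensor, loop) pair; pairs that were never set yield llvm::None.
  Optional<unsigned> d = merger.getDimNum(/*t=*/0, /*i=*/0);
  assert(d && *d == 0);
  assert(!merger.getDimNum(/*t=*/1, /*i=*/0));

  // Scan the bits of a lattice point: the callback receives the bit, the
  // tensor id, its (optional) dimension on this loop, and its dlt.
  unsigned e0 = merger.addExp(Kind::kTensor, 0, -1u);
  unsigned p0 = merger.addLat(/*t=*/0, /*i=*/0, e0);
  merger.foreachTidDimPairInBits(
      merger.lat(p0).bits,
      [](unsigned b, unsigned tid, Optional<unsigned> dim, DimLevelType dlt) {
        (void)b;
        assert(tid == 0 && dim && *dim == 0 && isCompressedDLT(dlt));
      });
}

Internally, foreachTidDimPairInBits just decodes each set bit b into tensor(b), getDimNum(b), and getDimLevelType(b), which is why call sites in Sparsification.cpp no longer need CodeGen's loopIdxToDim table.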