diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -430,6 +430,9 @@
       coords.push_back(l.iv);
   }
 
+  /// Gets the current depth of the loop stack.
+  unsigned getCurrentDepth() const { return loopStack.size(); }
+
   /// Gets loop induction variable at the given level.
   Value getLoopIV(size_t level) const {
     if (level < loopStack.size())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -90,6 +90,12 @@
   // Topsort (reference should remain in scope).
   std::vector<unsigned> &topSort;
 
+  /// Returns the prefix of `topSort` for the loops entered by `loopEmitter`.
+  ArrayRef<unsigned> getLoopCurStack() const {
+    ArrayRef<unsigned> topSortRef = topSort;
+    return topSortRef.slice(0, loopEmitter.getCurrentDepth());
+  }
+
   Value getLoopIdxValue(size_t loopIdx) const {
     for (unsigned lv = 0; lv < topSort.size(); lv++)
       if (topSort[lv] == loopIdx)
@@ -134,21 +140,81 @@
 // Sparse compiler analysis methods.
 //===----------------------------------------------------------------------===//
 
+/// Determines if affine expression `a` is invariant once the loops in
+/// `loopStack` have been generated. Sets `atLevel` to true if `a` refers
+/// to loop `ldx`, which counts as invariant.
+static bool isInvariantAffine(AffineExpr a, ArrayRef<unsigned> loopStack,
+                              unsigned ldx, bool &atLevel) {
+  switch (a.getKind()) {
+  case AffineExprKind::DimId: {
+    unsigned idx = a.cast<AffineDimExpr>().getPosition();
+    if (idx == ldx) {
+      atLevel = true;
+      // Must be invariant if we are at the level.
+      return true;
+    }
+    return llvm::is_contained(loopStack, idx);
+  }
+  case AffineExprKind::Add:
+  case AffineExprKind::Mul: {
+    auto binOp = a.cast<AffineBinaryOpExpr>();
+    return isInvariantAffine(binOp.getLHS(), loopStack, ldx, atLevel) &&
+           isInvariantAffine(binOp.getRHS(), loopStack, ldx, atLevel);
+  }
+  default: {
+    assert(a.isa<AffineConstantExpr>());
+    return true;
+  }
+  }
+}
+
+/// Determines if affine expression is invariant at the current loop depth
+/// of `codegen`.
+static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
+                              unsigned ldx, bool &atLevel) {
+  return isInvariantAffine(a, codegen.getLoopCurStack(), ldx, atLevel);
+}
+
 /// Helper method to construct a permuted dimension ordering
 /// that adheres to the given topological sort.
-static AffineMap permute(MLIRContext *context, AffineMap m,
-                         std::vector<unsigned> &topSort) {
+static AffineMap permute(const Merger &merger, MLIRContext *context,
+                         AffineMap m, ArrayRef<unsigned> topSort) {
   unsigned sz = topSort.size();
-  assert(m.getNumResults() == sz && "TopoSort/AffineMap size mismatch");
-  // Construct the inverse of `m`; to avoid the asymptotic complexity
-  // of calling `m.getPermutedPosition` repeatedly.
-  SmallVector<unsigned> inv(sz);
-  for (unsigned i = 0; i < sz; i++)
-    inv[i] = m.getDimPosition(i);
+  assert(m.getNumDims() + merger.getNumFilterLoops() == sz &&
+         "TopoSort/AffineMap size mismatch");
+  SmallVector<unsigned> perm;
+  unsigned numResults = m.getNumResults();
+  // Results of `m` that have not been placed in `perm` yet.
+  BitVector worklist(numResults, true);
+  // Number of leading loops of `topSort` that are considered generated.
+  unsigned loopDepth = 1;
+
   // Construct the permutation.
-  SmallVector<unsigned> perm(sz);
-  for (unsigned i = 0; i < sz; i++)
-    perm[i] = inv[topSort[i]];
+  while (worklist.any() && loopDepth <= sz) {
+    unsigned preSize = perm.size();
+    for (auto dim : worklist.set_bits()) {
+      bool atLevel = false;
+      if (m.getResult(dim).isa<AffineConstantExpr>() ||
+          (isInvariantAffine(m.getResult(dim), topSort.slice(0, loopDepth),
+                             topSort[loopDepth - 1], atLevel) &&
+           atLevel)) {
+        // The result is either a constant expression or has just become
+        // invariant at this loop depth, so visiting the dimension now does
+        // not break the topSort constraint.
+        perm.push_back(dim);
+      }
+    }
+
+    // Removes resolved dimensions; this is deferred until after the scan
+    // to avoid mutating `worklist` while iterating over its set bits.
+    for (unsigned i = preSize, e = perm.size(); i < e; i++)
+      worklist.reset(perm[i]);
+
+    // Tries to enter the next loop level.
+    loopDepth += 1;
+  }
+
+  assert(perm.size() == numResults);
   return AffineMap::getPermutationMap(perm, context);
 }
 
@@ -422,9 +488,6 @@
   auto iteratorTypes = op.getIteratorTypesArray();
   // Iterate over the indexing maps of every tensor in the tensor expression.
   for (OpOperand &t : op->getOpOperands()) {
-    // Skip tensor during cycle resolution.
-    if (&t == skip)
-      continue;
     // Get map and encoding.
     auto map = op.getMatchingIndexingMap(&t);
     auto enc = getSparseTensorEncoding(t.get().getType());
@@ -453,6 +516,11 @@
         ta = AffineExpr();
       }
 
+      // Skip tensor during cycle resolution, though order between filter loop
+      // and dependent loops needs to be guaranteed unconditionally.
+      if (&t == skip)
+        continue;
+
       if (d > 0) {
         AffineExpr fa = map.getResult(toOrigDim(enc, d - 1));
         Optional<unsigned> fldx =
@@ -945,30 +1013,6 @@
   return ee;
 }
 
-/// Determines if affine expression is invariant.
-static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
-                              unsigned ldx, bool &atLevel) {
-  switch (a.getKind()) {
-  case AffineExprKind::DimId: {
-    unsigned idx = a.cast<AffineDimExpr>().getPosition();
-    if (idx == ldx) {
-      atLevel = true;
-      // Must be invariant if we are at the level.
-      return true;
-    }
-    return codegen.getLoopIdxValue(idx) != nullptr; // no longer in play?
-  }
-  case AffineExprKind::Add:
-  case AffineExprKind::Mul: {
-    auto binOp = a.cast<AffineBinaryOpExpr>();
-    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
-           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
-  }
-  default:
-    return true;
-  }
-}
-
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp, unsigned ldx,
@@ -1428,7 +1472,6 @@
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized codegen.
-    // Only dense dimensions should be optimized from conditions.
     auto dim = merger.getDimNum(merger.getOutTensorID(), idx).value();
     extraTids.push_back(merger.getOutTensorID());
     extraDims.push_back(dim);
@@ -1698,7 +1741,7 @@
       auto srcTp = tval.getType().cast<RankedTensorType>();
       auto dstEnc = SparseTensorEncodingAttr::get(
           op->getContext(), srcEnc.getDimLevelType(),
-          permute(getContext(), op.getMatchingIndexingMap(t),
+          permute(merger, getContext(), op.getMatchingIndexingMap(t),
                   topSort), // new order
           srcEnc.getHigherOrdering(), srcEnc.getPointerBitWidth(),
           srcEnc.getIndexBitWidth());