diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -156,8 +156,12 @@
   Merger(unsigned t, unsigned l)
       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
         hasSparseOut(false),
-        dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)),
-        loopIdxToDim(t + 1, std::vector<Optional<unsigned>>(l, llvm::None)) {}
+        dimTypes(numTensors,
+                 std::vector<DimLevelType>(numLoops, DimLevelType::Undef)),
+        loopIdxToDim(numTensors,
+                     std::vector<Optional<unsigned>>(numLoops, llvm::None)),
+        dimToLoopIdx(numTensors,
+                     std::vector<Optional<unsigned>>(numLoops, llvm::None)) {}
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
@@ -258,6 +262,11 @@
     return getDimLevelType(tensor(b), index(b));
   }
 
+  Optional<unsigned> getLoopIdx(unsigned t, unsigned dim) const {
+    assert(t < numTensors && dim < numLoops);
+    return dimToLoopIdx[t][dim];
+  }
+
   /// Gets the dimension number of the the `t`th tensor on `i`th loop.
   Optional<unsigned> getDimNum(unsigned t, unsigned i) const {
     assert(t < numTensors && i < numLoops);
@@ -276,6 +285,8 @@
     assert(isValidDLT(dlt));
     dimTypes[t][i] = dlt;
     loopIdxToDim[t][i] = dim;
+    assert(dim < numLoops);
+    dimToLoopIdx[t][dim] = i;
   }
 
   // Iterates the bits of a lattice, for each set bit, converts it into the
@@ -341,6 +352,8 @@
   std::vector<std::vector<DimLevelType>> dimTypes;
   // Map that converts pair<tensor id, loop idx> to the corresponding dimension.
   std::vector<std::vector<Optional<unsigned>>> loopIdxToDim;
+  // Map that converts pair<tensor id, dimension> to the corresponding loop id.
+  std::vector<std::vector<Optional<unsigned>>> dimToLoopIdx;
   llvm::SmallVector<TensorExp> tensorExps;
   llvm::SmallVector<LatPoint> latPoints;
   llvm::SmallVector<SmallVector<unsigned>> latSets;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -318,7 +318,6 @@
 //   loopEmiter.exitCurrentLoop(); // exit i
 //===----------------------------------------------------------------------===//
 
-// TODO: Sparsification should also rely on this class to generate loops.
 class SparseTensorLoopEmitter {
 public:
   /// Optional callback function to setup dense output tensors when
@@ -381,6 +380,11 @@
                       ArrayRef<size_t> extraTids = {},
                       ArrayRef<size_t> extraDims = {});
+
+  void genDenseAffineAddressAtCurLevel(OpBuilder &builder, Location loc,
+                                       size_t tid, size_t dim,
+                                       AffineExpr affine);
+
   /// Emits a co-iteration loop over a set of tensors.
   Operation *enterCoIterationOverTensorsAtDims(
       OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -328,6 +328,14 @@
   return loop;
 }
 
+void SparseTensorLoopEmitter::genDenseAffineAddressAtCurLevel(
+    OpBuilder &builder, Location loc, size_t tid, size_t dim,
+    AffineExpr affine) {
+  Value affineV = genAffine(builder, affine, loc);
+  pidxs[tid][dim] = genAddress(builder, loc, tid, dim, affineV);
+}
+
 Operation *SparseTensorLoopEmitter::enterCoIterationOverTensorsAtDims(
     OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
     ArrayRef<size_t> dims, bool needsUniv, MutableArrayRef<Value> reduc,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TensorEncoding.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -94,6 +95,27 @@
   }
 };
 
+class ParallelAffineDimFinder
+    : public AffineExprVisitor<ParallelAffineDimFinder> {
+  AffineExpr paraDim;
+  utils::IteratorType pickIterType;
+  SmallVector<utils::IteratorType> iterTypes;
+
+public:
+  explicit ParallelAffineDimFinder(linalg::GenericOp op)
+      : iterTypes(op.getIteratorTypesArray()) {}
+
+  void visitDimExpr(AffineDimExpr expr) {
+    if (paraDim == nullptr || pickIterType == iterTypes[expr.getPosition()]) {
+      paraDim = expr;
+    }
+  }
+
+  void setPickedIterType(utils::IteratorType iterType) {
+    pickIterType = iterType;
+  }
+
+  AffineDimExpr getDimExpr() const { return paraDim.cast<AffineDimExpr>(); }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -215,25 +237,44 @@
 /// Helper method to add all constraints from the indices in one affine
 /// expression before all indices in the other affine expression. For
 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, i1<i3.
+/// The affine expression `a` is empty iff `fidx` has a value, leading to
+/// fidx < b = (i0 + i1) => fidx < i0, fidx < i1.
+/// The affine expression `b` is empty iff `tidx` has a value, leading to
+/// a = (i0 + i1) < tidx => i0 < tidx, i1 < tidx.
 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                                std::vector<unsigned> &inDegree, AffineExpr a,
-                               AffineExpr b, unsigned fidx) {
-  switch (a.getKind()) {
-  case AffineExprKind::DimId: {
-    unsigned idx = a.cast<AffineDimExpr>().getPosition();
-    if (b)
-      addAffineOrderings(adjM, inDegree, b, AffineExpr(), idx);
-    else if (!adjM[fidx][idx]) {
-      adjM[fidx][idx] = true;
-      inDegree[idx]++;
+                               AffineExpr b, Optional<unsigned> fidx,
+                               Optional<unsigned> tidx) {
+  if (!a && !b) {
+    // Recursion leaf.
+    assert(fidx && tidx);
+    unsigned f = *fidx, t = *tidx;
+    if (!adjM[f][t]) {
+      adjM[f][t] = true;
+      inDegree[t]++;
     }
+    return;
+  }
+  auto toExpand = a ? a : b;
+  switch (toExpand.getKind()) {
+  case AffineExprKind::DimId: {
+    auto idx = toExpand.cast<AffineDimExpr>().getPosition();
+    if (toExpand == a)
+      addAffineOrderings(adjM, inDegree, AffineExpr(), b, idx, tidx);
+    else // toExpand == b
+      addAffineOrderings(adjM, inDegree, a, AffineExpr(), fidx, idx);
     break;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
-    auto binOp = a.cast<AffineBinaryOpExpr>();
-    addAffineOrderings(adjM, inDegree, binOp.getLHS(), b, fidx);
-    addAffineOrderings(adjM, inDegree, binOp.getRHS(), b, fidx);
+    auto binOp = toExpand.cast<AffineBinaryOpExpr>();
+    if (toExpand == a) {
+      addAffineOrderings(adjM, inDegree, binOp.getLHS(), b, fidx, tidx);
+      addAffineOrderings(adjM, inDegree, binOp.getRHS(), b, fidx, tidx);
+    } else {
+      addAffineOrderings(adjM, inDegree, a, binOp.getLHS(), fidx, tidx);
+      addAffineOrderings(adjM, inDegree, a, binOp.getRHS(), fidx, tidx);
+    }
     break;
   }
   default:
@@ -271,10 +312,53 @@
       // by default) puts an ordering constraint on the loop indices. For
       // example, the tensor expresion A_ijk forces the ordering i < j < k
      // on the loop indices if no explicit dimension ordering is given.
-      for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
-        AffineExpr f = map.getResult(toOrigDim(enc, d - 1));
-        AffineExpr t = map.getResult(toOrigDim(enc, d));
-        addAffineOrderings(adjM, inDegree, f, t, 0);
+      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+        AffineExpr ta = map.getResult(toOrigDim(enc, d));
+        Optional<unsigned> tldx = merger.getLoopIdx(t.getOperandNumber(), d);
+
+        if (d > 0) {
+          AffineExpr fa = map.getResult(toOrigDim(enc, d - 1));
+          Optional<unsigned> fldx =
+              merger.getLoopIdx(t.getOperandNumber(), d - 1);
+
+          if (!(mask & SortMask::kIncludeDense) && !tldx) {
+            ParallelAffineDimFinder finder(op);
+            // E.g., for [dense, dense] -> (d0 + d1, d2 + d3) it is totally
+            // fine to have the loop sequence d0->d2->d1->d3 instead of
+            // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
+            // We use a heuristic here to only pick one dim expression from
+            // each compound affine expression to establish the order between
+            // two dense dimensions.
+            // NOTE: The ordering can only be loosened when the destination
+            // level is dense; for [dense, sparse] -> (d0 + d1, d2) we still
+            // require both d0 < d2 and d1 < d2 to ensure a correct ordering
+            // (i.e., no ordering like d0->d2->d1).
+            // TODO: this is obviously a suboptimal solution.
+            if (!fldx && fa.isa<AffineBinaryOpExpr>()) {
+              assert(isDenseDLT(getDimLevelType(enc, d - 1)) &&
+                     !fa.isa<AffineDimExpr>());
+              // Heuristic: we prefer a parallel loop for the lhs to reduce
+              // the chance of adding a reduction < parallel ordering.
+              finder.setPickedIterType(utils::IteratorType::parallel);
+              finder.walkPostOrder(fa);
+              fa = finder.getDimExpr();
+              fldx = finder.getDimExpr().getPosition();
+            }
+            if (!ta.isa<AffineDimExpr>()) {
+              // Dense compound affine expression.
+              assert(isDenseDLT(getDimLevelType(enc, d)) &&
+                     !ta.isa<AffineConstantExpr>());
+              // Heuristic: we prefer a reduction loop for the rhs to reduce
+              // the chance of adding a reduction < parallel ordering.
+              finder.setPickedIterType(utils::IteratorType::reduction);
+              finder.walkPostOrder(ta);
+              ta = finder.getDimExpr();
+              tldx = finder.getDimExpr().getPosition();
+            }
+          }
+
+          addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
+        }
       }
       // Push unrelated loops into sparse iteration space, so these
       // will be skipped more often.
@@ -638,8 +722,8 @@
     // Select operation insertion.
     Value insChain = codegen.insChain;
     assert(insChain);
-    scf::IfOp ifOp = builder.create<scf::IfOp>(
-        loc, insChain.getType(), rhs, /*else=*/true);
+    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, insChain.getType(), rhs,
+                                               /*else=*/true);
     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
     // Existing value was preserved to be used here.
     assert(merger.exp(exp).val);
@@ -1086,12 +1170,13 @@
   return false;
 }
 
-static void translateBitsToTidDimPairs(Merger &merger, CodeGen &codegen,
-                                       unsigned li, unsigned idx,
-                                       SmallVectorImpl<size_t> &condTids,
-                                       SmallVectorImpl<size_t> &condDims,
-                                       SmallVectorImpl<size_t> &extraTids,
-                                       SmallVectorImpl<size_t> &extraDims) {
+static void translateBitsToTidDimPairs(
+    Merger &merger, CodeGen &codegen, linalg::GenericOp op, unsigned li,
+    unsigned idx, SmallVectorImpl<size_t> &condTids,
+    SmallVectorImpl<size_t> &condDims, SmallVectorImpl<size_t> &extraTids,
+    SmallVectorImpl<size_t> &extraDims, SmallVectorImpl<size_t> &affineTids,
+    SmallVectorImpl<size_t> &affineDims, SmallVectorImpl<AffineExpr> &exps) {
+
   const BitVector &all = merger.lat(li).bits;
   const BitVector &simple = merger.lat(li).simple;
 
@@ -1119,6 +1204,53 @@
       // TODO: get rid of extraTids and extraDims.
      extraTids.push_back(tid);
      extraDims.push_back(dim.value());
+    } else {
+      assert(isUndefDLT(dlt));
+      if (tid >= op.getNumDpsInputs())
+        // We only handle affine expressions on input tensors (for now).
+        return;
+      OpOperand *operand = &op->getOpOperand(tid);
+      auto enc = getSparseTensorEncoding(operand->get().getType());
+      if (!enc)
+        // Non-annotated dense tensors require no special handling.
+        return;
+
+      ArrayRef<AffineExpr> affines =
+          op.getMatchingIndexingMap(operand).getResults();
+      assert(affines.size() == enc.getDimLevelType().size());
+      for (unsigned i = 0, e = affines.size(); i < e; i++) {
+        AffineExpr exp = affines[toOrigDim(enc, i)];
+        if (exp.isa<AffineDimExpr>() || !isDenseDLT(getDimLevelType(enc, i)))
+          // Skip simple affine expressions and non-dense dimensions (which
+          // have their own filter loops).
+          continue;
+
+        // Constant affine expressions on a dense level need to be generated
+        // when
+        // 1. the previous level is an (at-level) invariant compound dense
+        //    affine expression (with no corresponding loop idx); or
+        // 2. the previous level is being generated right now.
+        if (exp.isa<AffineConstantExpr>()) {
+          // TODO: Should we come up with a more cohesive way to handle
+          // constant expressions? We currently need two (somewhat ad hoc)
+          // pieces of code for it.
+          assert(false && "do not support constant affine");
+        }
+
+        bool atLevel = false;
+        if (isInvariantAffine(codegen, exp, idx, atLevel) && atLevel) {
+          // The compound affine expression is invariant and we are right at
+          // the level: generate the address according to the affine
+          // expression. This is also the best place to do it, to avoid
+          // putting it inside inner loops.
+          // NOTE: It assumes that the levels of the input tensor are
+          // initialized in order; a more flexible approach might be to accept
+          // out-of-order accesses between consecutive dense levels.
+          affineTids.push_back(tid);
+          affineDims.push_back(i);
+          exps.push_back(exp);
+        }
+      }
     }
   });
 
@@ -1142,12 +1274,23 @@
   // The set of (dense) tensors that is optimized from condition, yet still
   // need extra locals to iterate on them.
   SmallVector<size_t> extraTids, extraDims;
-
-  translateBitsToTidDimPairs(merger, codegen, li, codegen.topSort[at], condTids,
-                             condDims, extraTids, extraDims);
+  // The set of dense tensors with non-trivial affine expressions that have
+  // just become invariant; their addresses shall now be generated at the
+  // current level.
+  SmallVector<size_t> affineTids, affineDims;
+  SmallVector<AffineExpr> affines;
+
+  translateBitsToTidDimPairs(merger, codegen, op, li, codegen.topSort[at],
+                             condTids, condDims, extraTids, extraDims,
+                             affineTids, affineDims, affines);
   // Emit the for/while-loop control.
   Operation *loop = genLoop(merger, codegen, builder, op, at, needsUniv,
                             condTids, condDims, extraTids, extraDims);
+
+  for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) {
+    codegen.loopEmitter.genDenseAffineAddressAtCurLevel(builder, op.getLoc(),
+                                                        tid, dim, exp);
+  }
   return loop;
 }
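
Background sketch for the iteration-graph part of the patch: addAffineOrderings only records constraints of the form "loop f must be emitted outside loop t" in the adjM/inDegree pair, and the sparsifier later topologically sorts the loop indices from that graph. The standalone C++ below is not part of the patch; addOrdering and topoSort are illustrative stand-ins for the adjM/inDegree bookkeeping and the sort in Sparsification.cpp, shown only to make the heuristic's effect concrete.

// Standalone sketch (not part of the patch): how ordering constraints in an
// adjacency matrix plus in-degree vector turn into a loop order via a
// Kahn-style topological sort. Names are illustrative only.
#include <cstdio>
#include <vector>

// Record the constraint "loop f must be generated outside loop t".
static void addOrdering(std::vector<std::vector<bool>> &adjM,
                        std::vector<unsigned> &inDegree, unsigned f,
                        unsigned t) {
  if (!adjM[f][t]) {
    adjM[f][t] = true;
    inDegree[t]++;
  }
}

// Repeatedly emit a loop index whose remaining in-degree is zero; an empty
// result signals a cycle (i.e., no admissible loop order exists).
static std::vector<unsigned> topoSort(const std::vector<std::vector<bool>> &adjM,
                                      std::vector<unsigned> inDegree) {
  const unsigned n = static_cast<unsigned>(inDegree.size());
  std::vector<unsigned> order;
  std::vector<bool> done(n, false);
  for (unsigned step = 0; step < n; step++) {
    unsigned next = n;
    for (unsigned i = 0; i < n; i++)
      if (!done[i] && inDegree[i] == 0) {
        next = i;
        break;
      }
    if (next == n)
      return {}; // cycle detected
    done[next] = true;
    order.push_back(next);
    for (unsigned j = 0; j < n; j++)
      if (adjM[next][j])
        inDegree[j]--;
  }
  return order;
}

int main() {
  // Constraints for an access like (d0 + d1, d2) with a sparse second level:
  // both d0 < d2 and d1 < d2 are required, exactly as the NOTE in the patch
  // states for the [dense, sparse] case.
  unsigned n = 3;
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
  std::vector<unsigned> inDegree(n, 0);
  addOrdering(adjM, inDegree, 0, 2);
  addOrdering(adjM, inDegree, 1, 2);
  for (unsigned i : topoSort(adjM, inDegree))
    printf("d%u ", i); // prints: d0 d1 d2
  printf("\n");
  return 0;
}

For the [dense, dense] case the heuristic keeps only one picked pair (e.g. d0 < d2) instead of all four constraints, so orders such as d0->d2->d1->d3 remain admissible while the graph still stays acyclic.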
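The new genDenseAffineAddressAtCurLevel hook exists so that, once a compound affine index such as d0 + d1 becomes invariant at the current loop level, the dense-level position can be materialized right there rather than inside inner loops. A minimal sketch of the underlying arithmetic follows, assuming the usual row-major dense-level position computation (parent position times level size plus the level coordinate); flatten is a hypothetical helper, while the real lowering emits the equivalent MLIR ops through genAffine/genAddress.

// Standalone sketch (not part of the patch): the position arithmetic a dense
// level with a compound index such as (d0 + d1) ultimately needs.
#include <cstdio>
#include <vector>

// offset_k = offset_{k-1} * size_k + coord_k, where coord_k may itself be an
// affine function of loop indices (e.g. i + j).
static unsigned flatten(const std::vector<unsigned> &coords,
                        const std::vector<unsigned> &sizes) {
  unsigned offset = 0;
  for (size_t k = 0; k < coords.size(); k++)
    offset = offset * sizes[k] + coords[k];
  return offset;
}

int main() {
  // A dense 4x5 tensor accessed as B(i + j, k): for i = 1, j = 2, k = 3 the
  // first-level coordinate is i + j = 3, so the flat position is 3 * 5 + 3.
  unsigned i = 1, j = 2, k = 3;
  printf("%u\n", flatten({i + j, k}, {4, 5})); // prints: 18
  return 0;
}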