diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -978,7 +978,6 @@
                           ArrayRef<size_t> extraTids,
                           ArrayRef<size_t> extraDims) {
   Location loc = op.getLoc();
-  auto iteratorTypes = op.getIteratorTypesArray();
   bool isSparse = isCompressedDLT(merger.getDimLevelType(tid, idx)) ||
                   isSingletonDLT(merger.getDimLevelType(tid, idx));
   bool isParallel = isParallelFor(codegen, isOuter, isSparse);
@@ -1234,21 +1233,28 @@
       // TODO: Should we come up with a more adhersive way to handle
       // constant expression? We now requires two (somehow ad-hoc) code for
       // it.
-      assert(false && "do not support constant affine");
-    }
-
-    bool atLevel = false;
-    if (isInvariantAffine(codegen, exp, idx, atLevel) && atLevel) {
-      // If the compound affine is invariant and we are right at the
-      // level. We need to generate the address according to the affine
-      // expression. This is also the best place we can do it to avoid
-      // putting it inside inner loops.
-      // NOTE: It assumes that the levels of the input tensor are
-      // initialized in order, another more admissible approach might be
-      // accepting out-of-order access between consecutive dense levels.
-      affineTids.push_back(tid);
-      affineDims.push_back(i);
-      exps.push_back(exp);
+      if (i != 0 && // i == 0 cases are handled in genConstantDenseAddress
+          ((!affineTids.empty() && affineTids.back() == tid &&
+            affineDims.back() == i - 1) || // Condition 1
+           merger.getLoopIdx(tid, i - 1) == idx)) { // Condition 2
+        affineTids.push_back(tid);
+        affineDims.push_back(i);
+        exps.push_back(exp);
+      }
+    } else {
+      bool atLevel = false;
+      if (isInvariantAffine(codegen, exp, idx, atLevel) && atLevel) {
+        // If the compound affine is invariant and we are right at the
+        // level. We need to generate the address according to the affine
+        // expression. This is also the best place we can do it to avoid
+        // putting it inside inner loops.
+        // NOTE: It assumes that the levels of the input tensor are
+        // initialized in order, another more admissible approach might be
+        // accepting out-of-order access between consecutive dense levels.
+        affineTids.push_back(tid);
+        affineDims.push_back(i);
+        exps.push_back(exp);
+      }
     }
   }
 }
@@ -1413,12 +1419,37 @@
   }
 }
 
+static void genConstantDenseAddress(CodeGen &codegen, RewriterBase &rewriter,
+                                    linalg::GenericOp op) {
+  // We can generates address for constant affine expression before any loops
+  // starting from the first level as they do not depend on any thing.
+  // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+  // levels can be determined before loops.
+  for (OpOperand *input : op.getDpsInputOperands()) {
+    ArrayRef<AffineExpr> affines =
+        op.getMatchingIndexingMap(input).getResults();
+    auto enc = getSparseTensorEncoding(input->get().getType());
+    if (enc) {
+      for (unsigned i = 0, e = affines.size(); i < e; i++) {
+        AffineExpr affine = affines[toOrigDim(enc, i)];
+        if (isDenseDLT(getDimLevelType(enc, i)) &&
+            affine.isa<AffineConstantExpr>()) {
+          codegen.loopEmitter.genDenseAffineAddressAtCurLevel(
+              rewriter, op.getLoc(), input->getOperandNumber(), i, affine);
+        } else {
+          // Breaks on first non-dense non-constant level.
+          break;
+        }
+      }
+    }
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse compiler rewriting methods.
 //===----------------------------------------------------------------------===//
 
 namespace {
 
-/// Sparse rewriting rule for generic Lingalg operation.
 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
 public:
@@ -1486,6 +1517,7 @@
     CodeGen codegen(options, tensors, numTensors, numLoops, sparseOut,
                     outerParNest, topSort);
     genBuffers(merger, codegen, rewriter, op);
+    genConstantDenseAddress(codegen, rewriter, op);
     genStmt(merger, codegen, rewriter, op, exp, 0);
     genResult(merger, codegen, rewriter, op);
     return success();