diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -623,11 +623,10 @@ /// The function will also perform in-place update on the `reduc` vector to /// return the reduction variable used inside the generated loop. Operation *enterLoopOverTensorAtDim(OpBuilder &builder, Location loc, - size_t tid, size_t dim, + ArrayRef tids, + ArrayRef dims, MutableArrayRef reduc = {}, - bool isParallel = false, - ArrayRef extraTids = {}, - ArrayRef extraDims = {}); + bool isParallel = false); Operation *enterFilterLoopOverTensorAtDim(OpBuilder &builder, Location loc, size_t tid, size_t dim, @@ -641,8 +640,7 @@ /// Emits a co-iteration loop over a set of tensors. Operation *enterCoIterationOverTensorsAtDims( OpBuilder &builder, Location loc, ArrayRef tids, - ArrayRef dims, bool needsUniv, MutableArrayRef reduc = {}, - ArrayRef extraTids = {}, ArrayRef extraDims = {}); + ArrayRef dims, bool needsUniv, MutableArrayRef reduc = {}); void exitCurrentLoop(RewriterBase &rewriter, Location loc, MutableArrayRef reduc = {}); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -380,22 +380,32 @@ } Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim( - OpBuilder &builder, Location loc, size_t tid, size_t dim, - MutableArrayRef reduc, bool isParallel, ArrayRef extraTids, - ArrayRef extraDims) { - - assert(dimTypes[tid].size() > dim); - // We can not re-enter the same level. - assert(!coord[tid][dim]); + OpBuilder &builder, Location loc, ArrayRef tids, + ArrayRef dims, MutableArrayRef reduc, bool isParallel) { // TODO: support multiple return on parallel for? assert(!isParallel || reduc.size() <= 1); - Value step = constantIndex(builder, loc, 1); - auto dimType = dimTypes[tid][dim]; - bool isSparseInput = isCompressedDLT(dimType) || isSingletonDLT(dimType); - assert(isDenseDLT(dimType) || isCompressedDLT(dimType) || - isSingletonDLT(dimType)); + bool isSparseInput = false; + size_t tid = tids.front(), dim = dims.front(); + for (auto [t, d] : llvm::zip(tids, dims)) { + assert(dimTypes[t].size() > d); // Must be a valid tid, dim pair + assert(!coord[t][d]); // We cannot re-enter the same level + auto dimType = dimTypes[t][d]; + // Must be a recognizable DLT. + assert(isDenseDLT(dimType) || isCompressedDLT(dimType) || + isSingletonDLT(dimType)); + bool isSparse = isCompressedDLT(dimType) || isSingletonDLT(dimType); + // We can at most have one sparse input, otherwise, a while loop is required + // to co-iterate multiple sparse tensors. + assert(!isSparseInput || !isSparse); + if (isSparse) { + tid = t; + dim = d; + } + isSparseInput = isSparseInput || isSparse; + } + Value step = constantIndex(builder, loc, 1); Value lo = isSparseInput ? pidxs[tid][dim] // current offset : loopSeqStack.back(); // univeral tid Value hi = highs[tid][dim]; @@ -439,18 +449,13 @@ } else { // Dense tensor, the coordinates is the inducation variable. coord[tid][dim] = iv; - // generate pidx for dense dim (pidx = i * sz + j) - auto enc = getSparseTensorEncoding(tensors[tid].getType()); - if (enc && !isSparseOutput(tid)) - pidxs[tid][dim] = genAddress(builder, loc, tid, dim, iv); } - - // NOTE: we can also prepares for next dim here in advance + // NOTE: we can also prepare for next dim here in advance // Push the loop into stack loopStack.emplace_back(ArrayRef(tid), ArrayRef(dim), loop, coord[tid][dim], loopTag); // Emit extra locals. - emitExtraLocalsForTensorsAtDenseDims(builder, loc, extraTids, extraDims); + emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims); return loop; } @@ -531,8 +536,7 @@ Operation *SparseTensorLoopEmitter::enterCoIterationOverTensorsAtDims( OpBuilder &builder, Location loc, ArrayRef tids, - ArrayRef dims, bool needsUniv, MutableArrayRef reduc, - ArrayRef extraTids, ArrayRef extraDims) { + ArrayRef dims, bool needsUniv, MutableArrayRef reduc) { assert(tids.size() == dims.size()); SmallVector types; SmallVector operands; @@ -611,24 +615,12 @@ min = after->getArguments().back(); } - for (auto [tid, dim] : llvm::zip(tids, dims)) { - // All dense dim (as well as sparse output tensor) shared the same pidx in - // the while loop. - if (isDenseDLT(dimTypes[tid][dim])) { - pidxs[tid][dim] = min; - // generate pidx for dense dim (pidx = i * sz + j) - auto enc = getSparseTensorEncoding(tensors[tid].getType()); - if (enc && !isSparseOutput(tid)) - pidxs[tid][dim] = genAddress(builder, loc, tid, dim, min); - } - // NOTE: we can also prepares for next dim here in advance - } // Sets up the loop stack. loopStack.emplace_back(tids, dims, whileOp, min, loopTag); assert(loopStack.size() == loopSeqStack.size()); // Emits extra locals - emitExtraLocalsForTensorsAtDenseDims(builder, loc, extraTids, extraDims); + emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims); // Updates reduction variables assert(after->getNumArguments() == o + reduc.size() + (needsUniv ? 1 : 0)); @@ -682,18 +674,20 @@ // output tensor unconditionally, since they may not appear in the lattice, // but may be needed for linearized codegen. for (auto [tid, dim] : llvm::zip(tids, dims)) { - assert(isDenseDLT(dimTypes[tid][dim])); - auto enc = getSparseTensorEncoding(tensors[tid].getType()); - if (enc && !isSparseOutput(tid)) { - bool validPidx = dim == 0 || pidxs[tid][dim - 1]; - if (!validPidx) { - // We might not find the pidx for the sparse output tensor as it is - // unconditionally required by the sparsification. - assert(isOutputTensor(tid)); - continue; + if (isDenseDLT(dimTypes[tid][dim])) { + auto enc = getSparseTensorEncoding(tensors[tid].getType()); + if (enc && !isSparseOutput(tid)) { + bool validPidx = dim == 0 || pidxs[tid][dim - 1]; + if (!validPidx) { + // We might not find the pidx for the sparse output tensor as it is + // unconditionally required by the sparsification. + assert(isOutputTensor(tid)); + continue; + } + pidxs[tid][dim] = + genAddress(builder, loc, tid, dim, loopStack.back().iv); + // NOTE: we can also prepares for next dim here in advance } - pidxs[tid][dim] = genAddress(builder, loc, tid, dim, loopStack.back().iv); - // NOTE: we can also prepares for next dim here in advance } } } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1028,21 +1028,24 @@ /// Generates a for-loop on a single index. static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter, - bool isInner, unsigned idx, size_t tid, size_t dim, - ArrayRef extraTids, - ArrayRef extraDims) { + bool isInner, unsigned idx, ArrayRef tids, + ArrayRef dims) { linalg::GenericOp op = env.op(); Location loc = op.getLoc(); auto iteratorTypes = op.getIteratorTypesArray(); - bool isSparse = - isCompressedDLT(env.dlt(tid, idx)) || isSingletonDLT(env.dlt(tid, idx)); + bool isSparse = llvm::any_of(tids, [idx, &env](size_t tid) { + return isCompressedDLT(env.dlt(tid, idx)) || + isSingletonDLT(env.dlt(tid, idx)); + }); + bool isParallel = isParallelFor(env, isOuter, isSparse); Operation *loop = *env.genLoopBoundary([&](MutableArrayRef reduc) { if (env.merger().isFilterLoop(idx)) { - // extraTids/extraDims must be empty because filter loops only + size_t tid = tids.front(), dim = dims.front(); + // tids/dims must only have one value because filter loops only // corresponding to the one and only sparse tensor level. - assert(isSparse && extraTids.empty() && extraDims.empty()); + assert(isSparse && tids.size() == 1 && dims.size() == 1); OpOperand *t = &op->getOpOperand(tid); auto enc = getSparseTensorEncoding(t->get().getType()); // Retrieves the affine expression for the filter loop. @@ -1051,8 +1054,8 @@ return env.emitter()->enterFilterLoopOverTensorAtDim(builder, loc, tid, dim, a, reduc); } - return env.emitter()->enterLoopOverTensorAtDim( - builder, loc, tid, dim, reduc, isParallel, extraTids, extraDims); + return env.emitter()->enterLoopOverTensorAtDim(builder, loc, tids, dims, + reduc, isParallel); }); assert(loop); return loop; @@ -1060,16 +1063,13 @@ /// Emit a while-loop for co-iteration over multiple indices. static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, unsigned idx, - bool needsUniv, ArrayRef condTids, - ArrayRef condDims, - ArrayRef extraTids, - ArrayRef extraDims) { + bool needsUniv, ArrayRef tids, + ArrayRef dims) { Operation *loop = *env.genLoopBoundary([&](MutableArrayRef reduc) { // Construct the while-loop with a parameter for each // index. return env.emitter()->enterCoIterationOverTensorsAtDims( - builder, env.op().getLoc(), condTids, condDims, needsUniv, reduc, - extraTids, extraDims); + builder, env.op().getLoc(), tids, dims, needsUniv, reduc); }); assert(loop); return loop; @@ -1078,20 +1078,16 @@ /// Generates a for-loop or a while-loop, depending on whether it implements /// singleton iteration or co-iteration over the given conjunction. static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, unsigned at, - bool needsUniv, ArrayRef condTids, - ArrayRef condDims, ArrayRef extraTids, - ArrayRef extraDims) { - assert(condTids.size() == condDims.size()); - assert(extraTids.size() == extraDims.size()); + bool needsUniv, ArrayRef tids, + ArrayRef dims, bool isFor) { + assert(tids.size() == dims.size()); unsigned idx = env.topSortAt(at); - if (condTids.size() == 1) { + if (isFor) { bool isOuter = at == 0; bool isInner = at == env.topSortSize() - 1; - return genFor(env, builder, isOuter, isInner, idx, condTids.front(), - condDims.front(), extraTids, extraDims); + return genFor(env, builder, isOuter, isInner, idx, tids, dims); } - return genWhile(env, builder, idx, needsUniv, condTids, condDims, extraTids, - extraDims); + return genWhile(env, builder, idx, needsUniv, tids, dims); } /// Generates the induction structure for a while-loop. @@ -1263,15 +1259,15 @@ genConstantDenseAddressFromLevel(env, rewriter, tid, 0); } -static void translateBitsToTidDimPairs( - CodegenEnv &env, unsigned li, unsigned idx, - SmallVectorImpl &condTids, SmallVectorImpl &condDims, - SmallVectorImpl &extraTids, SmallVectorImpl &extraDims, - SmallVectorImpl &affineTids, SmallVectorImpl &affineDims, - SmallVectorImpl &exps) { +/// Return true if the lattices bit can be iterated by a for loop. +static bool translateBitsToTidDimPairs( + CodegenEnv &env, unsigned li, unsigned idx, SmallVectorImpl &tids, + SmallVectorImpl &dims, SmallVectorImpl &affineTids, + SmallVectorImpl &affineDims, SmallVectorImpl &exps) { const BitVector &all = env.lat(li).bits; const BitVector &simple = env.lat(li).simple; + unsigned numloopCond = 0; // Converts bits to array + dim pair env.merger().foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid, Optional dim, @@ -1290,12 +1286,12 @@ if (!dim) return; } - condTids.push_back(tid); - condDims.push_back(*dim); + tids.push_back(tid); + dims.push_back(*dim); + numloopCond++; } else if (isDenseDLT(dlt)) { - // TODO: get rid of extraTids and extraDims. - extraTids.push_back(tid); - extraDims.push_back(*dim); + tids.push_back(tid); + dims.push_back(*dim); } else { assert(isUndefDLT(dlt)); linalg::GenericOp op = env.op(); @@ -1344,31 +1340,31 @@ // unconditionally, since they may not appear in the lattice, but may be // needed for linearized env. auto dim = *env.merger().getDimNum(env.merger().getOutTensorID(), idx); - extraTids.push_back(env.merger().getOutTensorID()); - extraDims.push_back(dim); + tids.push_back(env.merger().getOutTensorID()); + dims.push_back(dim); } + + assert(numloopCond > 0); + // If we just need to one loop conditions, the loop can be generated by a for + // loop. + return numloopCond == 1; } /// Starts a single loop in current sequence. static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at, unsigned li, bool needsUniv) { // The set of tensors + dims to generate loops on - SmallVector condTids, condDims; - // The set of (dense) tensors that is optimized from condition, yet still - // need extra locals to iterate on them. - SmallVector extraTids, extraDims; + SmallVector tids, dims; // The set of dense tensors with non-trivial affine expression that just // becomes invariant and the address shall now be generated at the current // level. SmallVector affineTids, affineDims; SmallVector affines; - translateBitsToTidDimPairs(env, li, env.topSortAt(at), condTids, condDims, - extraTids, extraDims, affineTids, affineDims, - affines); + bool isFor = translateBitsToTidDimPairs( + env, li, env.topSortAt(at), tids, dims, affineTids, affineDims, affines); // Emit the for/while-loop control. - Operation *loop = genLoop(env, builder, at, needsUniv, condTids, condDims, - extraTids, extraDims); + Operation *loop = genLoop(env, builder, at, needsUniv, tids, dims, isFor); for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) { env.emitter()->genDenseAffineAddressAtCurLevel(builder, env.op().getLoc(), tid, dim, exp); @@ -1377,8 +1373,8 @@ // Until now, we have entered every pair in {cond, extra, // affine}Tids/Dims. The addresses of the upcoming levels which are dependent // on constant affines expression may now be determined. - auto allTids = llvm::concat(condTids, extraTids, affineTids); - auto allDims = llvm::concat(condDims, extraDims, affineDims); + auto allTids = llvm::concat(tids, affineTids); + auto allDims = llvm::concat(dims, affineDims); for (auto [tid, dim] : llvm::zip(allTids, allDims)) { if (tid != env.merger().getOutTensorID()) genConstantDenseAddressFromLevel(env, builder, tid, dim + 1);