diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -46,9 +46,9 @@ linalg::GenericOp op() const { return linalgOp; } const SparsificationOptions &options() const { return sparseOptions; } Merger &merger() { return latticeMerger; } - LoopEmitter *emitter() { return loopEmitter; } + LoopEmitter &emitter() { return loopEmitter; } - void startEmit(OpOperand *so, unsigned lv, LoopEmitter *le); + void startEmit(OpOperand *so, unsigned lv); /// Generates loop boundary statements (entering/exiting loops). The function /// passes and updates the passed-in parameters. @@ -74,9 +74,6 @@ // Topological delegate and sort methods. // - // TODO: get rid of this one! - std::vector<unsigned> &topSortRef() { return topSort; } - size_t topSortSize() const { return topSort.size(); } unsigned topSortAt(unsigned i) const { return topSort.at(i); } void topSortPushBack(unsigned i) { topSort.push_back(i); } @@ -134,9 +131,8 @@ // Merger helper class. Merger latticeMerger; - // Loop emitter helper class (keep reference in scope!). - // TODO: move emitter constructor up in time? - LoopEmitter *loopEmitter; + // Loop emitter helper class. + LoopEmitter loopEmitter; // Topological sort. 
std::vector<unsigned> topSort; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -19,21 +19,29 @@ unsigned numTensors, unsigned numLoops, unsigned numFilterLoops) : linalgOp(linop), sparseOptions(opts), - latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr), + latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(), redExp(-1u), redCustom(-1u) {} -void CodegenEnv::startEmit(OpOperand *so, unsigned lv, LoopEmitter *le) { - assert(sparseOut == nullptr && loopEmitter == nullptr && - insChain == nullptr && "must only start emitting once"); +void CodegenEnv::startEmit(OpOperand *so, unsigned lv) { + assert(sparseOut == nullptr && insChain == nullptr && + "must only start emitting once"); sparseOut = so; outerParNest = lv; - loopEmitter = le; if (sparseOut) { insChain = sparseOut->get(); latticeMerger.setHasSparseOut(true); } + // Initialize loop emitter. 
+ SmallVector<Value> tensors; + for (OpOperand &t : linalgOp->getOpOperands()) + tensors.push_back(t.get()); + loopEmitter.initialize(tensors, + StringAttr::get(linalgOp.getContext(), + linalg::GenericOp::getOperationName()), + /*hasOutput=*/true, + /*isSparseOut=*/sparseOut != nullptr, topSort); } Optional<Operation *> CodegenEnv::genLoopBoundary( @@ -66,13 +74,13 @@ } ArrayRef<unsigned> CodegenEnv::getLoopCurStack() const { - return getTopSortSlice(0, loopEmitter->getCurrentDepth()); + return getTopSortSlice(0, loopEmitter.getCurrentDepth()); } Value CodegenEnv::getLoopIdxValue(size_t loopIdx) const { for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++) if (topSort[lv] == loopIdx) - return loopEmitter->getLoopIV(lv); + return loopEmitter.getLoopIV(lv); llvm_unreachable("invalid loop index"); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -611,7 +611,7 @@ Location loc = op.getLoc(); assert(op.getNumOperands() == op.getNumDpsInputs() + 1); - env.emitter()->initializeLoopEmit( + env.emitter().initializeLoopEmit( builder, loc, /// Generates buffer for the output tensor. 
/// Note that all sparse kernels assume that when all elements are written @@ -666,16 +666,16 @@ auto enc = getSparseTensorEncoding(t->get().getType()); unsigned rank = map.getNumResults(); if (enc) { - Value pidx = env.emitter()->getPidxs()[tensor].back(); + Value pidx = env.emitter().getPidxs()[tensor].back(); assert(pidx); args.push_back(pidx); // position index } else { for (unsigned d = 0; d < rank; d++) { AffineExpr a = map.getResult(d); - args.push_back(env.emitter()->genAffine(builder, a, op.getLoc())); + args.push_back(env.emitter().genAffine(builder, a, op.getLoc())); } } - return env.emitter()->getValBuffer()[tensor]; + return env.emitter().getValBuffer()[tensor]; } /// Generates insertion code to implement dynamic tensor load. @@ -721,8 +721,8 @@ unsigned rank = op.getRank(t); SmallVector<Value> indices; for (unsigned i = 0; i < rank; i++) { - assert(env.emitter()->getLoopIV(i)); - indices.push_back(env.emitter()->getLoopIV(i)); + assert(env.emitter().getLoopIV(i)); + indices.push_back(env.emitter().getLoopIV(i)); } Value chain = env.getInsertionChain(); env.updateInsertionChain( @@ -988,7 +988,7 @@ } else { SmallVector<Value> indices; for (unsigned i = 0; i < at; i++) - indices.push_back(env.emitter()->getLoopIV(i)); + indices.push_back(env.emitter().getLoopIV(i)); Value values = env.getExpandValues(); Value filled = env.getExpandFilled(); Value added = env.getExpandAdded(); @@ -1052,11 +1052,11 @@ // Retrieves the affine expression for the filter loop. 
AffineExpr a = op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim)); - return env.emitter()->enterFilterLoopOverTensorAtDim(builder, loc, tid, - dim, a, reduc); + return env.emitter().enterFilterLoopOverTensorAtDim(builder, loc, tid, - dim, a, reduc); } - return env.emitter()->enterLoopOverTensorAtDim(builder, loc, tids, dims, - reduc, isParallel); + return env.emitter().enterLoopOverTensorAtDim(builder, loc, tids, dims, - reduc, isParallel); }); assert(loop); return loop; @@ -1069,7 +1069,7 @@ Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) { // Construct the while-loop with a parameter for each // index. - return env.emitter()->enterCoIterationOverTensorsAtDims( + return env.emitter().enterCoIterationOverTensorsAtDims( builder, env.op().getLoc(), tids, dims, needsUniv, reduc); }); assert(loop); @@ -1136,7 +1136,7 @@ Value clause; if (isCompressedDLT(env.dlt(b)) || isSingletonDLT(env.dlt(b))) { auto dim = *env.merger().getDimNum(tensor, idx); - Value op1 = env.emitter()->getCoord()[tensor][dim]; + Value op1 = env.emitter().getCoord()[tensor][dim]; Value op2 = env.getLoopIdxValue(idx); clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1, op2); @@ -1213,7 +1213,7 @@ } }); - env.emitter()->enterNewLoopSeq(builder, env.op().getLoc(), tids, dims); + env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, dims); // Maintain the universal index only if it is actually // consumed by a subsequent lattice point. @@ -1242,7 +1242,7 @@ AffineExpr affine = affines[toOrigDim(enc, i)]; if (isDenseDLT(getDimLevelType(enc, i)) && affine.isa<AffineConstantExpr>()) - env.emitter()->genDenseAffineAddressAtCurLevel( + env.emitter().genDenseAffineAddressAtCurLevel( builder, op.getLoc(), input->getOperandNumber(), i, affine); else return; // break on first non-dense non-constant level @@ -1367,8 +1367,8 @@ // Emit the for/while-loop control. 
Operation *loop = genLoop(env, builder, at, needsUniv, tids, dims, isFor); for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) { - env.emitter()->genDenseAffineAddressAtCurLevel(builder, env.op().getLoc(), - tid, dim, exp); + env.emitter().genDenseAffineAddressAtCurLevel(builder, env.op().getLoc(), + tid, dim, exp); } // Until now, we have entered every pair in {cond, extra, @@ -1395,7 +1395,7 @@ } env.genLoopBoundary([&](MutableArrayRef<Value> reduc) { - env.emitter()->exitCurrentLoop(rewriter, env.op().getLoc(), reduc); + env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc); return std::nullopt; }); @@ -1406,7 +1406,7 @@ static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, unsigned at, unsigned idx, unsigned ldx) { assert(env.getLoopIdxValue(idx) == nullptr); - env.emitter()->exitCurrentLoopSeq(); + env.emitter().exitCurrentLoopSeq(); // Unmark bookkeeping of invariants and loop index. genInvariants(env, builder, exp, ldx, /*atStart=*/false); // Finalize access pattern expansion for sparse tensor output. @@ -1492,7 +1492,7 @@ } else { // To rematerialize an non-annotated tensor, simply load it // from the bufferized value. - Value val = env.emitter()->getValBuffer().back(); // value array + Value val = env.emitter().getValBuffer().back(); // value array rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val); } } @@ -1559,20 +1559,8 @@ if (!isAdmissible) return failure(); // inadmissible expression, reject - // Updates environment with a loop emitter. - // TODO: refactor so that emitter can be constructed earlier - // and updating is made easy, i.e. remove this whole block? 
- SmallVector<Value> tensors; - for (OpOperand &t : op->getOpOperands()) - tensors.push_back(t.get()); - LoopEmitter lpe( - tensors, - StringAttr::get(op.getContext(), linalg::GenericOp::getOperationName()), - /*hasOutput=*/true, /*isSparseOut=*/sparseOut != nullptr, - env.topSortRef()); - env.startEmit(sparseOut, outerParNest, &lpe); - // Recursively generates code if admissible. + env.startEmit(sparseOut, outerParNest); genBuffers(env, rewriter); genInitConstantDenseAddress(env, rewriter); genStmt(env, rewriter, exp, 0);