diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -28,59 +28,87 @@
 /// sparsification. This environment simplifies passing around such
 /// data during sparsification (rather than passing around all the
 /// individual compoments where needed). Furthermore, it provides
-/// a number of delegate and convience methods that keep some of the
-/// implementation details transparent to sparsification.
+/// convenience methods that keep implementation details transparent
+/// to sparsification while asserting on internal consistency.
 class CodegenEnv {
 public:
+  /// Constructs a code generation environment which can be
+  /// passed around during sparsification for bookkeeping
+  /// together with some consistency asserts.
   CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
              unsigned numTensors, unsigned numLoops,
              unsigned numFilterLoops);
 
-  // Start emitting.
-  void startEmit(SparseTensorLoopEmitter *le);
+  //
+  // General methods.
+  //
+
+  linalg::GenericOp op() const { return linalgOp; }
+  const SparsificationOptions &options() const { return sparseOptions; }
+  Merger &merger() { return latticeMerger; }
+  SparseTensorLoopEmitter *emitter() { return loopEmitter; }
+
+  void startEmit(OpOperand *so, unsigned lv, SparseTensorLoopEmitter *le);
 
-  // Delegate methods to merger.
-  TensorExp &exp(unsigned e) { return merger.exp(e); }
-  LatPoint &lat(unsigned l) { return merger.lat(l); }
-  SmallVector<unsigned> &set(unsigned s) { return merger.set(s); }
-  DimLevelType dimLevelType(unsigned t, unsigned i) const {
-    return merger.getDimLevelType(t, i);
+  //
+  // Merger delegates.
+  //
+
+  TensorExp &exp(unsigned e) { return latticeMerger.exp(e); }
+  LatPoint &lat(unsigned l) { return latticeMerger.lat(l); }
+  SmallVector<unsigned> &set(unsigned s) { return latticeMerger.set(s); }
+  DimLevelType dlt(unsigned t, unsigned i) const {
+    return latticeMerger.getDimLevelType(t, i);
   }
-  DimLevelType dimLevelType(unsigned b) const {
-    return merger.getDimLevelType(b);
+  DimLevelType dlt(unsigned b) const {
+    return latticeMerger.getDimLevelType(b);
   }
-  bool isFilterLoop(unsigned i) const { return merger.isFilterLoop(i); }
 
-  // Delegate methods to loop emitter.
-  Value getLoopIV(unsigned i) const { return loopEmitter->getLoopIV(i); }
-  const std::vector<Value> &getValBuffer() const {
-    return loopEmitter->getValBuffer();
-  }
+  //
+  // Topological delegate and sort methods.
+  //
 
-  // Convenience method to slice topsort.
-  ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const {
-    return ArrayRef<unsigned>(topSort).slice(n, m);
-  }
+  // TODO: get rid of this one!
+  std::vector<unsigned> &topSortRef() { return topSort; }
 
-  // Convenience method to get current loop stack.
-  ArrayRef<unsigned> getLoopCurStack() const {
-    return getTopSortSlice(0, loopEmitter->getCurrentDepth());
+  size_t topSortSize() const { return topSort.size(); }
+  unsigned topSortAt(unsigned i) const { return topSort.at(i); }
+  void topSortPushBack(unsigned i) { topSort.push_back(i); }
+  void topSortClear(unsigned capacity = 0) {
+    topSort.clear();
+    topSort.reserve(capacity);
   }
 
-  // Convenience method to get the IV of the given loop index.
-  Value getLoopIdxValue(size_t loopIdx) const {
-    for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++)
-      if (topSort[lv] == loopIdx)
-        return getLoopIV(lv);
-    llvm_unreachable("invalid loop index");
-  }
+  ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const;
+  ArrayRef<unsigned> getLoopCurStack() const;
+  Value getLoopIdxValue(size_t loopIdx) const;
+
+  //
+  // Sparse tensor output and expansion methods.
+  //
+
+  bool hasSparseOutput() const { return sparseOut != nullptr; }
+  bool isSparseOutput(OpOperand *o) const { return sparseOut == o; }
+
+  Value getInsertionChain() const { return insChain; }
+  void updateInsertionChain(Value chain);
+
+  bool atExpandLevel(OpOperand *o, unsigned rank, unsigned lv) const;
+  void startExpand(Value values, Value filled, Value added, Value count);
+  bool isExpand() const { return expValues != nullptr; }
+  void updateExpandCount(Value count);
+  Value getExpandValues() const { return expValues; }
+  Value getExpandFilled() const { return expFilled; }
+  Value getExpandAdded() const { return expAdded; }
+  Value getExpandCount() const { return expCount; }
+  void endExpand();
 
   //
-  // Reductions.
+  // Reduction methods.
   //
 
   void startReduc(unsigned exp, Value val);
-  void updateReduc(Value val);
   bool isReduc() const { return redExp != -1u; }
+  void updateReduc(Value val);
   Value getReduc() const { return redVal; }
   Value endReduc();
 
@@ -89,39 +117,34 @@
   Value getCustomRedId();
   void endCustomReduc();
 
-public:
-  //
-  // TODO make this section private too, using similar refactoring as for reduc
-  //
-
+private:
   // Linalg operation.
   linalg::GenericOp linalgOp;
 
   // Sparsification options.
-  SparsificationOptions options;
-
-  // Topological sort.
-  std::vector<unsigned> topSort;
+  SparsificationOptions sparseOptions;
 
   // Merger helper class.
-  Merger merger;
+  Merger latticeMerger;
 
   // Loop emitter helper class (keep reference in scope!).
   // TODO: move emitter constructor up in time?
   SparseTensorLoopEmitter *loopEmitter;
 
+  // Topological sort.
+  std::vector<unsigned> topSort;
+
   // Sparse tensor as output. Implemented either through direct injective
-  // insertion in lexicographic index order or through access pattern expansion
-  // in the innermost loop nest (`expValues` through `expCount`).
+  // insertion in lexicographic index order or through access pattern
+  // expansion in the innermost loop nest (`expValues` through `expCount`).
   OpOperand *sparseOut;
   unsigned outerParNest;
-  Value insChain; // bookkeeping for insertion chain
+  Value insChain;
   Value expValues;
   Value expFilled;
   Value expAdded;
   Value expCount;
 
-private:
   // Bookkeeping for reductions (up-to-date value of the reduction, and indices
   // into the merger's expression tree. When the indices of a tensor reduction
   // expression are exhausted, all inner loops can use a scalarized reduction.
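Reviewer sketch (illustrative only, not part of the patch): the new expansion accessors above form a small start/update/end state machine whose invariants the asserts in CodegenEnv.cpp below enforce. A minimal, MLIR-free model of that contract, with `Value` stood in by `std::optional<int>` and every name hypothetical:

// toy_expand.cpp -- hypothetical toy model of the assert-guarded expansion
// life cycle (startExpand -> updateExpandCount* -> endExpand); not MLIR code.
#include <cassert>
#include <optional>

class ToyCodegenEnv {
public:
  bool isExpand() const { return expValues.has_value(); }
  void startExpand(int values, int filled, int added, int count) {
    assert(!isExpand() && "must start expansion only once");
    expValues = values;
    expFilled = filled;
    expAdded = added;
    expCount = count;
  }
  void updateExpandCount(int count) {
    assert(isExpand() && "no expansion in flight");
    expCount = count;
  }
  void endExpand() {
    assert(isExpand() && "no expansion in flight");
    expValues = expFilled = expAdded = expCount = std::nullopt;
  }

private:
  std::optional<int> expValues, expFilled, expAdded, expCount;
};

int main() {
  ToyCodegenEnv env;
  env.startExpand(/*values=*/0, /*filled=*/1, /*added=*/2, /*count=*/3);
  env.updateExpandCount(4); // e.g., after an scf.if yields a new count
  env.endExpand();          // e.g., after compressing into the chain
  assert(!env.isExpand());
  return 0;
}

Modeling the handles as optionals makes "no expansion in flight" an explicit, assertable state rather than an implicit null Value, which is the design point of hiding these fields behind accessors.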
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -12,27 +12,84 @@
 using namespace mlir::sparse_tensor;
 
 //===----------------------------------------------------------------------===//
-// Code generation environment constructor and setup
+// Code generation environment constructor and general methods
 //===----------------------------------------------------------------------===//
 
 CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
                        unsigned numTensors, unsigned numLoops,
                        unsigned numFilterLoops)
-    : linalgOp(linop), options(opts), topSort(),
-      merger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr),
-      sparseOut(nullptr), redVal(nullptr), redExp(-1u), redCustom(-1u) {}
+    : linalgOp(linop), sparseOptions(opts),
+      latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr),
+      topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
+      expFilled(), expAdded(), expCount(), redVal(), redExp(-1u),
+      redCustom(-1u) {}
 
-void CodegenEnv::startEmit(SparseTensorLoopEmitter *le) {
-  assert(!loopEmitter && "must only start emitting once");
+void CodegenEnv::startEmit(OpOperand *so, unsigned lv,
+                           SparseTensorLoopEmitter *le) {
+  assert(sparseOut == nullptr && loopEmitter == nullptr &&
+         insChain == nullptr && "must only start emitting once");
+  sparseOut = so;
+  outerParNest = lv;
   loopEmitter = le;
   if (sparseOut) {
     insChain = sparseOut->get();
-    merger.setHasSparseOut(true);
+    latticeMerger.setHasSparseOut(true);
   }
 }
 
 //===----------------------------------------------------------------------===//
-// Code generation environment methods
+// Code generation environment topological sort methods
+//===----------------------------------------------------------------------===//
+
+ArrayRef<unsigned> CodegenEnv::getTopSortSlice(size_t n, size_t m) const {
+  return ArrayRef<unsigned>(topSort).slice(n, m);
+}
+
+ArrayRef<unsigned> CodegenEnv::getLoopCurStack() const {
+  return getTopSortSlice(0, loopEmitter->getCurrentDepth());
+}
+
+Value CodegenEnv::getLoopIdxValue(size_t loopIdx) const {
+  for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++)
+    if (topSort[lv] == loopIdx)
+      return loopEmitter->getLoopIV(lv);
+  llvm_unreachable("invalid loop index");
+}
+
+//===----------------------------------------------------------------------===//
+// Code generation environment sparse tensor output and expansion methods
+//===----------------------------------------------------------------------===//
+
+void CodegenEnv::updateInsertionChain(Value chain) {
+  assert(sparseOut != nullptr && insChain != nullptr);
+  insChain = chain;
+}
+
+bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, unsigned lv) const {
+  return sparseOut == o && outerParNest == rank - 1 && outerParNest == lv;
+}
+
+void CodegenEnv::startExpand(Value values, Value filled, Value added,
+                             Value count) {
+  assert(sparseOut != nullptr && expValues == nullptr);
+  expValues = values;
+  expFilled = filled;
+  expAdded = added;
+  expCount = count;
+}
+
+void CodegenEnv::updateExpandCount(Value count) {
+  assert(sparseOut != nullptr && expValues != nullptr);
+  expCount = count;
+}
+
+void CodegenEnv::endExpand() {
+  assert(sparseOut != nullptr && expValues != nullptr);
+  expValues = expFilled = expAdded = expCount = Value();
+}
+
+//===----------------------------------------------------------------------===//
+// Code generation environment reduction methods
 //===----------------------------------------------------------------------===//
 
 void CodegenEnv::startReduc(unsigned exp, Value val) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -127,8 +127,8 @@
 /// Helper method to construct a permuted dimension ordering
 /// that adheres to the given topological sort.
 static AffineMap permute(CodegenEnv &env, AffineMap m) {
-  assert(m.getNumDims() + env.merger.getNumFilterLoops() ==
-             env.topSort.size() &&
+  assert(m.getNumDims() + env.merger().getNumFilterLoops() ==
+             env.topSortSize() &&
          "size mismatch");
   // Construct the inverse of `m`; to avoid the asymptotic complexity
   // of calling `m.getPermutedPosition` repeatedly.
@@ -138,14 +138,14 @@
   unsigned loopDepth = 1;
 
   // Construct the permutation.
-  while (worklist.any() && loopDepth <= env.topSort.size()) {
+  while (worklist.any() && loopDepth <= env.topSortSize()) {
     unsigned preSize = perm.size();
     for (auto dim : worklist.set_bits()) {
       bool atLevel = false;
       if (m.getResult(dim).isa<AffineConstantExpr>() ||
           (isInvariantAffine(m.getResult(dim),
                              env.getTopSortSlice(0, loopDepth),
-                             env.topSort[loopDepth - 1], atLevel) &&
+                             env.topSortAt(loopDepth - 1), atLevel) &&
           atLevel)) {
         // If the matching affine is constant expression or just become
         // invariant. We can visit the dimension now without breaking the
@@ -163,7 +163,7 @@
   }
 
   assert(perm.size() == numResults);
-  return AffineMap::getPermutationMap(perm, env.linalgOp.getContext());
+  return AffineMap::getPermutationMap(perm, env.op().getContext());
 }
 
 /// Helper method to inspect affine expressions. Rejects cases where the
@@ -255,22 +255,22 @@
 /// no annotations are found or inadmissible constructs occur.
 static bool findSparseAnnotations(CodegenEnv &env) {
   bool annotated = false;
-  unsigned filterLdx = env.merger.getFilterLoopStartingIdx();
-  for (OpOperand &t : env.linalgOp->getOpOperands()) {
-    auto map = env.linalgOp.getMatchingIndexingMap(&t);
+  unsigned filterLdx = env.merger().getFilterLoopStartingIdx();
+  for (OpOperand &t : env.op()->getOpOperands()) {
+    auto map = env.op().getMatchingIndexingMap(&t);
     auto enc = getSparseTensorEncoding(t.get().getType());
     if (enc)
       annotated = true;
-    assert(map.getNumResults() == env.linalgOp.getRank(&t));
+    assert(map.getNumResults() == env.op().getRank(&t));
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned tensor = t.getOperandNumber();
       AffineExpr a = map.getResult(toOrigDim(enc, d));
-      if (!findAffine(env.merger, tensor, d, a, getDimLevelType(enc, d),
+      if (!findAffine(env.merger(), tensor, d, a, getDimLevelType(enc, d),
                       filterLdx))
         return false; // inadmissible affine expression
     }
   }
-  assert(filterLdx == env.merger.getNumLoops());
+  assert(filterLdx == env.merger().getNumLoops());
   return annotated;
 }
 
@@ -287,7 +287,7 @@
   std::vector<unsigned> filterIt; // filter loop with 0 degree
   for (unsigned i = 0; i < n; i++) {
     if (inDegree[i] == 0) {
-      if (env.isFilterLoop(i))
+      if (env.merger().isFilterLoop(i))
         filterIt.push_back(i);
       else if (linalg::isReductionIterator(iteratorTypes[i]))
         redIt.push_back(i);
@@ -318,12 +318,12 @@
     // O(X) computation => O(NK+NMX) time complexity
     auto &it = !filterIt.empty() ? filterIt : (!parIt.empty() ? parIt : redIt);
     auto src = it.back();
-    env.topSort.push_back(src);
+    env.topSortPushBack(src);
     it.pop_back();
     // Update in-degree, and push 0-degree node into worklist.
     for (unsigned dst = 0; dst < n; dst++) {
       if (adjM[src][dst] && --inDegree[dst] == 0) {
-        if (env.isFilterLoop(dst))
+        if (env.merger().isFilterLoop(dst))
          filterIt.push_back(dst);
        else if (linalg::isReductionIterator(iteratorTypes[dst]))
          redIt.push_back(dst);
@@ -332,7 +332,7 @@
       }
     }
   }
-  return env.topSort.size() == n;
+  return env.topSortSize() == n;
 }
 
 /// Helper method to add all constraints from the indices in one affine
@@ -428,17 +428,16 @@
                                   OpOperand *skip = nullptr) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
-  unsigned n = env.merger.getNumLoops();
+  unsigned n = env.merger().getNumLoops();
   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
   std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
-  auto iteratorTypes = env.linalgOp.getIteratorTypesArray();
+  auto iteratorTypes = env.op().getIteratorTypesArray();
   // Iterate over the indexing maps of every tensor in the tensor expression.
-  for (OpOperand &t : env.linalgOp->getOpOperands()) {
+  for (OpOperand &t : env.op()->getOpOperands()) {
     // Get map and encoding.
-    auto map = env.linalgOp.getMatchingIndexingMap(&t);
+    auto map = env.op().getMatchingIndexingMap(&t);
     auto enc = getSparseTensorEncoding(t.get().getType());
-    assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.linalgOp) ==
-           n);
+    assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.op()) == n);
     // Skip dense tensor constraints when not requested.
     if (!(mask & SortMask::kIncludeDense) && !enc)
       continue;
@@ -448,11 +447,12 @@
     // on the loop indices if no explicit dimension ordering is given.
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       AffineExpr ta = map.getResult(toOrigDim(enc, d));
-      Optional<unsigned> tldx = env.merger.getLoopIdx(t.getOperandNumber(), d);
+      Optional<unsigned> tldx =
+          env.merger().getLoopIdx(t.getOperandNumber(), d);
 
       // Filter loops should be constructed after all the dependent loops,
       // i.e., d0 + d1 < filter_loop(d0 + d1)
-      if (tldx && env.isFilterLoop(*tldx)) {
+      if (tldx && env.merger().isFilterLoop(*tldx)) {
         assert(!ta.isa<AffineDimExpr>() &&
                !isDenseDLT(getDimLevelType(enc, d)));
         addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt,
@@ -472,7 +472,7 @@
       if (d > 0) {
         AffineExpr fa = map.getResult(toOrigDim(enc, d - 1));
         Optional<unsigned> fldx =
-            env.merger.getLoopIdx(t.getOperandNumber(), d - 1);
+            env.merger().getLoopIdx(t.getOperandNumber(), d - 1);
 
         // Applying order constraints on every pair of dimExpr between two
         // compound affine expressions can sometime too strict:
@@ -480,7 +480,7 @@
         // It is totally fine to have loop sequence d0->d2->d1->d3 instead of
         // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
         if (!(mask & SortMask::kIncludeDense))
-          tryLoosenAffineDenseConstraints(env.linalgOp, fldx, fa, tldx, ta);
+          tryLoosenAffineDenseConstraints(env.op(), fldx, fa, tldx, ta);
 
         // (d0 + d1) < (d2 + d3), or
         // filter_loop_d-1 < (d2 + d3), or
@@ -495,23 +495,22 @@
     if (mask & SortMask::kIncludeUndef) {
       unsigned tensor = t.getOperandNumber();
       for (unsigned i = 0; i < n; i++)
-        if (isCompressedDLT(env.dimLevelType(tensor, i)) ||
-            isSingletonDLT(env.dimLevelType(tensor, i))) {
+        if (isCompressedDLT(env.dlt(tensor, i)) ||
+            isSingletonDLT(env.dlt(tensor, i))) {
          for (unsigned j = 0; j < n; j++)
-            if (isUndefDLT(env.dimLevelType(tensor, j))) {
+            if (isUndefDLT(env.dlt(tensor, j))) {
              adjM[i][j] = true;
              inDegree[j]++;
            }
        } else {
-          assert(isDenseDLT(env.dimLevelType(tensor, i)) ||
-                 isUndefDLT(env.dimLevelType(tensor, i)));
+          assert(isDenseDLT(env.dlt(tensor, i)) ||
+                 isUndefDLT(env.dlt(tensor, i)));
        }
     }
   }
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
-  env.topSort.clear();
-  env.topSort.reserve(n);
+  env.topSortClear(n);
   return topSortOptimal(env, n, iteratorTypes, inDegree, adjM);
 }
 
@@ -526,20 +525,22 @@
 /// whether the out tensor in the tensor expression codegen is admissible.
 /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
 /// nesting depth when a "truly dynamic" sparse tensor output occurs.
-static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp) {
+static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp,
+                                  OpOperand **sparseOut,
+                                  unsigned *outerParNest) {
   // We reject any expression that makes a reduction from `-outTensor`, as those
   // expressions create a dependency between the current iteration (i) and the
   // previous iteration (i-1). It would require iterating over the whole
   // coordinate space, which prevent exploiting sparsity for faster code.
-  for (utils::IteratorType it : env.linalgOp.getIteratorTypesArray()) {
+  for (utils::IteratorType it : env.op().getIteratorTypesArray()) {
     if (it == utils::IteratorType::reduction) {
-      if (env.merger.hasNegateOnOut(exp))
+      if (env.merger().hasNegateOnOut(exp))
         return false;
       break;
     }
   }
 
-  OpOperand *lhs = env.linalgOp.getDpsInitOperand(0);
+  OpOperand *lhs = env.op().getDpsInitOperand(0);
   unsigned tensor = lhs->getOperandNumber();
   auto enc = getSparseTensorEncoding(lhs->get().getType());
   // An non-annotated output tensor is assumed dense, and becomes a random
@@ -550,40 +551,39 @@
   // access 1-dim memref. Also admissible since insertions cannot occur.
   bool allDense = true;
   unsigned numLoops =
-      env.merger.getNumLoops(); // numNativeLoops + numFilterLoops
-  for (unsigned i = 0; i < env.merger.getNumLoops(); i++)
-    if (isCompressedDLT(env.dimLevelType(tensor, i)) ||
-        isSingletonDLT(env.dimLevelType(tensor, i))) {
+      env.merger().getNumLoops(); // numNativeLoops + numFilterLoops
+  for (unsigned i = 0; i < env.merger().getNumLoops(); i++)
+    if (isCompressedDLT(env.dlt(tensor, i)) ||
+        isSingletonDLT(env.dlt(tensor, i))) {
      allDense = false;
      break;
    } else {
-      assert(isDenseDLT(env.dimLevelType(tensor, i)) ||
-             isUndefDLT(env.dimLevelType(tensor, i)));
+      assert(isDenseDLT(env.dlt(tensor, i)) || isUndefDLT(env.dlt(tensor, i)));
    }
   if (allDense)
     return true;
 
   // TODO: support compound affine expression on sparse output.
-  if (getNumCompoundAffineOnSparseDims(env.linalgOp.getMatchingIndexingMap(lhs),
+  if (getNumCompoundAffineOnSparseDims(env.op().getMatchingIndexingMap(lhs),
                                        lhs->get()) != 0)
     return false;
 
   // A tensor expression with a sparse output tensor that changes its values
   // but not its nonzero structure, an operation called "simply dynamic" in
   // [Bik96,Ch9], is also admissible without special env.
-  if (env.merger.isSingleCondition(tensor, exp))
+  if (env.merger().isSingleCondition(tensor, exp))
     return true;
 
   // Accept "truly dynamic" if the output tensor materializes uninitialized
   // into the computation and insertions occur in lexicographic index order.
   if (isMaterializing(lhs->get())) {
-    auto iteratorTypes = env.linalgOp.getIteratorTypesArray();
+    auto iteratorTypes = env.op().getIteratorTypesArray();
     unsigned nest = 0;
     for (unsigned i = 0; i < numLoops; i++) {
-      if (!env.isFilterLoop(env.topSort[i])) {
+      if (!env.merger().isFilterLoop(env.topSortAt(i))) {
         // We only count non-filter loops as filter loops should be considered
         // as a special type of parallel loops.
-        if (linalg::isReductionIterator(iteratorTypes[env.topSort[i]]))
+        if (linalg::isReductionIterator(iteratorTypes[env.topSortAt(i)]))
          break; // terminate at first reduction
        nest++;
      }
@@ -591,9 +591,9 @@
     // Determine admissible dynamic insertion situations:
     // (1) fully injective, since there are no reductions,
     // (2) admissible 1-d expansion in innermost dimension.
-    if (nest >= env.linalgOp.getRank(lhs) - 1) {
-      env.sparseOut = lhs;
-      env.outerParNest = nest;
+    if (nest >= env.op().getRank(lhs) - 1) {
+      *sparseOut = lhs;
+      *outerParNest = nest;
       return true;
     }
   }
@@ -613,10 +613,10 @@
   SmallVector<Value> reduc;
   if (env.isReduc())
     reduc.push_back(env.getReduc());
-  if (env.expValues)
-    reduc.push_back(env.expCount);
-  if (env.insChain)
-    reduc.push_back(env.insChain);
+  if (env.isExpand())
+    reduc.push_back(env.getExpandCount());
+  if (env.getInsertionChain())
+    reduc.push_back(env.getInsertionChain());
 
   auto r = callback(reduc);
 
@@ -624,21 +624,21 @@
   unsigned i = 0;
   if (env.isReduc())
     env.updateReduc(reduc[i++]);
-  if (env.expValues)
-    env.expCount = reduc[i++];
-  if (env.insChain)
-    env.insChain = reduc[i];
+  if (env.isExpand())
+    env.updateExpandCount(reduc[i++]);
+  if (env.getInsertionChain())
+    env.updateInsertionChain(reduc[i]);
 
   return r;
 }
 
 /// Local bufferization of all dense and sparse data structures.
 static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
 
-  env.loopEmitter->initializeLoopEmit(
+  env.emitter()->initializeLoopEmit(
       builder, loc,
       /// Generates buffer for the output tensor.
       /// Note that all sparse kernels assume that when all elements are written
@@ -676,7 +676,7 @@
 
 /// Generates index for load/store on sparse tensor.
 static Value genIndex(CodegenEnv &env, OpOperand *t) {
-  auto map = env.linalgOp.getMatchingIndexingMap(t);
+  auto map = env.op().getMatchingIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
   AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1));
   assert(a.getKind() == AffineExprKind::DimId);
@@ -687,69 +687,73 @@
 
 /// Generates subscript for load/store on a dense or sparse tensor.
 static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                           SmallVectorImpl<Value> &args) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   unsigned tensor = t->getOperandNumber();
   auto map = op.getMatchingIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
   unsigned rank = map.getNumResults();
   if (enc) {
-    Value pidx = env.loopEmitter->getPidxs()[tensor].back();
+    Value pidx = env.emitter()->getPidxs()[tensor].back();
     assert(pidx);
     args.push_back(pidx); // position index
   } else {
     for (unsigned d = 0; d < rank; d++) {
       AffineExpr a = map.getResult(d);
-      args.push_back(env.loopEmitter->genAffine(builder, a, op.getLoc()));
+      args.push_back(env.emitter()->genAffine(builder, a, op.getLoc()));
     }
   }
-  return env.getValBuffer()[tensor];
+  return env.emitter()->getValBuffer()[tensor];
 }
 
 /// Generates insertion code to implement dynamic tensor load.
 static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
                               OpOperand *t) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   // Direct lexicographic index order, tensor loads as zero.
-  if (!env.expValues) {
+  if (!env.isExpand()) {
     Type tp = getElementTypeOrSelf(t->get().getType());
     return constantZero(builder, loc, tp);
   }
   // Load from expanded access pattern.
   Value index = genIndex(env, t);
-  return builder.create<memref::LoadOp>(loc, env.expValues, index);
+  return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
 }
 
 /// Generates insertion code to implement dynamic tensor load for reduction.
 static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
                                     OpOperand *t) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   Value identity = env.getCustomRedId();
   // Direct lexicographic index order, tensor loads as identity.
-  if (!env.expValues)
+  if (!env.isExpand())
     return identity;
   // Load from expanded access pattern if filled, identity otherwise.
+  Value values = env.getExpandValues();
+  Value filled = env.getExpandFilled();
   Value index = genIndex(env, t);
-  Value isFilled = builder.create<memref::LoadOp>(loc, env.expFilled, index);
-  Value valAtIndex = builder.create<memref::LoadOp>(loc, env.expValues, index);
+  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
+  Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
   return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
 }
 
 /// Generates insertion code to implement dynamic tensor store.
 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                               Value rhs) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   // Direct insertion in lexicographic index order.
-  if (!env.expValues) {
+  if (!env.isExpand()) {
     unsigned rank = op.getRank(t);
     SmallVector<Value> indices;
     for (unsigned i = 0; i < rank; i++) {
-      assert(env.getLoopIV(i));
-      indices.push_back(env.getLoopIV(i));
+      assert(env.emitter()->getLoopIV(i));
+      indices.push_back(env.emitter()->getLoopIV(i));
     }
-    env.insChain = builder.create<InsertOp>(loc, rhs, env.insChain, indices);
+    Value chain = env.getInsertionChain();
+    env.updateInsertionChain(
+        builder.create<InsertOp>(loc, rhs, chain, indices));
    return;
   }
   // Generates insertion code along expanded access pattern.
@@ -758,29 +762,33 @@
   //     expAdded[inserts++] = i
   //   endif
   //   values[i] = rhs
+  Value values = env.getExpandValues();
+  Value filled = env.getExpandFilled();
+  Value added = env.getExpandAdded();
+  Value count = env.getExpandCount();
   Value index = genIndex(env, t);
   Value fval = constantI1(builder, loc, false);
   Value tval = constantI1(builder, loc, true);
   // If statement.
-  Value filled = builder.create<memref::LoadOp>(loc, env.expFilled, index);
+  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                             filled, fval);
+                                             isFilled, fval);
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
                                              /*else=*/true);
   // True branch.
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  builder.create<memref::StoreOp>(loc, tval, env.expFilled, index);
-  builder.create<memref::StoreOp>(loc, index, env.expAdded, env.expCount);
+  builder.create<memref::StoreOp>(loc, tval, filled, index);
+  builder.create<memref::StoreOp>(loc, index, added, count);
   Value one = constantIndex(builder, loc, 1);
-  Value add = builder.create<arith::AddIOp>(loc, env.expCount, one);
+  Value add = builder.create<arith::AddIOp>(loc, count, one);
   builder.create<scf::YieldOp>(loc, add);
   // False branch.
   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-  builder.create<scf::YieldOp>(loc, env.expCount);
+  builder.create<scf::YieldOp>(loc, count);
   builder.setInsertionPointAfter(ifOp);
   // Value assignment.
-  env.expCount = ifOp.getResult(0);
-  builder.create<memref::StoreOp>(loc, rhs, env.expValues, index);
+  env.updateExpandCount(ifOp.getResult(0));
+  builder.create<memref::StoreOp>(loc, rhs, values, index);
 }
 
 /// Generates a load on a dense or sparse tensor.
@@ -791,23 +799,23 @@
     return val;
 
   // Load during insertion.
-  linalg::GenericOp op = env.linalgOp;
-  OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
-  if (&t == env.sparseOut) {
+  linalg::GenericOp op = env.op();
+  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
+  if (env.isSparseOutput(t)) {
     if (env.isCustomReduc())
-      return genInsertionLoadReduce(env, builder, &t);
-    return genInsertionLoad(env, builder, &t);
+      return genInsertionLoadReduce(env, builder, t);
+    return genInsertionLoad(env, builder, t);
   }
 
   // Actual load.
   SmallVector<Value> args;
-  Value ptr = genSubscript(env, builder, &t, args);
+  Value ptr = genSubscript(env, builder, t, args);
   return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
 }
 
 /// Generates a store on a dense or sparse tensor.
 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, unsigned exp,
                            Value rhs) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   // Test if this is a scalarized reduction.
   if (env.isReduc()) {
@@ -816,17 +824,16 @@
   }
   // Store during insertion.
   OpOperand *t = op.getDpsInitOperand(0);
-  if (t == env.sparseOut) {
+  if (env.isSparseOutput(t)) {
     if (!rhs) {
       // Only unary and binary are allowed to return uninitialized rhs
       // to indicate missing output.
       assert(env.exp(exp).kind == kUnary || env.exp(exp).kind == kBinary);
     } else if (env.exp(exp).kind == kSelect) {
       // Select operation insertion.
-      Value insChain = env.insChain;
-      assert(insChain);
-      scf::IfOp ifOp = builder.create<scf::IfOp>(loc, insChain.getType(), rhs,
-                                                 /*else=*/true);
+      Value chain = env.getInsertionChain();
+      scf::IfOp ifOp =
+          builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
       builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
       // Existing value was preserved to be used here.
       assert(env.exp(exp).val);
@@ -834,12 +841,13 @@
       genInsertionStore(env, builder, t, v0);
       env.exp(exp).val = Value();
       // Yield modified insertion chain along true branch.
-      builder.create<scf::YieldOp>(op.getLoc(), env.insChain);
+      Value mchain = env.getInsertionChain();
+      builder.create<scf::YieldOp>(op.getLoc(), mchain);
       // Yield original insertion chain along false branch.
       builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-      builder.create<scf::YieldOp>(loc, insChain);
+      builder.create<scf::YieldOp>(loc, chain);
       // Done with if statement.
-      env.insChain = ifOp->getResult(0);
+      env.updateInsertionChain(ifOp->getResult(0));
       builder.setInsertionPointAfter(ifOp);
     } else {
       genInsertionStore(env, builder, t, rhs);
@@ -884,7 +892,7 @@
 /// Recursively generates tensor expression.
 static Value genExp(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
                     unsigned ldx) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
 
   if (exp == -1u)
@@ -901,7 +909,7 @@
   Value v0 = genExp(env, rewriter, env.exp(exp).children.e0, ldx);
   Value v1 = genExp(env, rewriter, env.exp(exp).children.e1, ldx);
-  Value ee = env.merger.buildExp(rewriter, loc, exp, v0, v1);
+  Value ee = env.merger().buildExp(rewriter, loc, exp, v0, v1);
   if (ee && (env.exp(exp).kind == Kind::kUnary ||
              env.exp(exp).kind == Kind::kBinary ||
             env.exp(exp).kind == Kind::kBinaryBranch ||
@@ -928,14 +936,15 @@
   if (env.exp(exp).kind == Kind::kTensor) {
     // Inspect tensor indices.
     bool atLevel = ldx == -1u;
-    linalg::GenericOp op = env.linalgOp;
+    linalg::GenericOp op = env.op();
     OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
     auto map = op.getMatchingIndexingMap(&t);
     auto enc = getSparseTensorEncoding(t.get().getType());
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       AffineExpr a = map.getResult(toOrigDim(enc, d));
-      Optional<unsigned> sldx = env.merger.getLoopIdx(t.getOperandNumber(), d);
-      if (sldx && env.isFilterLoop(*sldx)) {
+      Optional<unsigned> sldx =
+          env.merger().getLoopIdx(t.getOperandNumber(), d);
+      if (sldx && env.merger().isFilterLoop(*sldx)) {
         if (!env.getLoopIdxValue(*sldx))
           // The filter loops has not been constructed.
           return;
@@ -978,11 +987,11 @@
 }
 
 /// Generates an expanded access pattern in innermost dimension.
-static void genExpansion(CodegenEnv &env, OpBuilder &builder, unsigned at,
-                         bool atStart) {
-  linalg::GenericOp op = env.linalgOp;
-  OpOperand *lhs = env.sparseOut;
-  if (!lhs || env.outerParNest != op.getRank(lhs) - 1 || at != env.outerParNest)
+static void genExpand(CodegenEnv &env, OpBuilder &builder, unsigned at,
+                      bool atStart) {
+  linalg::GenericOp op = env.op();
+  OpOperand *lhs = op.getDpsInitOperand(0);
+  if (!env.atExpandLevel(lhs, op.getRank(lhs), at))
     return; // not needed at this level
   assert(!env.isReduc());
   // Generate start or end of an expanded access pattern. Note that because
@@ -999,25 +1008,23 @@
   Type t2 = MemRefType::get(dynShape, builder.getI1Type());
   Type t3 = MemRefType::get(dynShape, builder.getIndexType());
   Type t4 = builder.getIndexType();
-  auto res =
-      builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
-  assert(res.getNumResults() == 4);
-  assert(!env.expValues);
-  env.expValues = res.getResult(0);
-  env.expFilled = res.getResult(1);
-  env.expAdded = res.getResult(2);
-  env.expCount = res.getResult(3);
+  auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
+  assert(r.getNumResults() == 4);
+  env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
+                  r.getResult(3));
   } else {
-    assert(env.expValues);
     SmallVector<Value> indices;
-    for (unsigned i = 0; i < at; i++) {
-      assert(env.getLoopIV(i));
-      indices.push_back(env.getLoopIV(i));
-    }
-    env.insChain = builder.create<CompressOp>(loc, env.expValues, env.expFilled,
-                                              env.expAdded, env.expCount,
-                                              env.insChain, indices);
-    env.expValues = env.expFilled = env.expAdded = env.expCount = Value();
+    for (unsigned i = 0; i < at; i++)
+      indices.push_back(env.emitter()->getLoopIV(i));
+    Value values = env.getExpandValues();
+    Value filled = env.getExpandFilled();
+    Value added = env.getExpandAdded();
+    Value count = env.getExpandCount();
+    Value chain = env.getInsertionChain();
+    Value compress = builder.create<CompressOp>(loc, values, filled, added,
+                                                count, chain, indices);
+    env.updateInsertionChain(compress);
+    env.endExpand();
   }
 }
 
@@ -1026,13 +1033,13 @@
 /// converted to a parallel operation depends on the requested strategy.
 static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
   // Reject parallelization of sparse output.
-  if (env.sparseOut)
+  if (env.hasSparseOutput())
     return false;
   // Parallel loops on tensor expansion can cause data races.
-  if (env.expCount)
+  if (env.isExpand())
     return false;
   // Inspect strategy.
-  switch (env.options.parallelizationStrategy) {
+  switch (env.options().parallelizationStrategy) {
   case SparseParallelizationStrategy::kNone:
     return false;
   case SparseParallelizationStrategy::kDenseOuterLoop:
@@ -1052,15 +1059,15 @@
                           bool isInner, unsigned idx, size_t tid, size_t dim,
                           ArrayRef<size_t> extraTids,
                           ArrayRef<size_t> extraDims) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   auto iteratorTypes = op.getIteratorTypesArray();
-  bool isSparse = isCompressedDLT(env.dimLevelType(tid, idx)) ||
-                  isSingletonDLT(env.dimLevelType(tid, idx));
+  bool isSparse =
+      isCompressedDLT(env.dlt(tid, idx)) || isSingletonDLT(env.dlt(tid, idx));
   bool isParallel = isParallelFor(env, isOuter, isSparse);
 
   Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
-    if (env.isFilterLoop(idx)) {
+    if (env.merger().isFilterLoop(idx)) {
       // extraTids/extraDims must be empty because filter loops only
       // corresponding to the one and only sparse tensor level.
       assert(isSparse && extraTids.empty() && extraDims.empty());
@@ -1069,10 +1076,10 @@
       // Retrieves the affine expression for the filter loop.
       AffineExpr a =
           op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim));
-      return env.loopEmitter->enterFilterLoopOverTensorAtDim(builder, loc, tid,
-                                                             dim, a, reduc);
+      return env.emitter()->enterFilterLoopOverTensorAtDim(builder, loc, tid,
+                                                           dim, a, reduc);
     }
-    return env.loopEmitter->enterLoopOverTensorAtDim(
+    return env.emitter()->enterLoopOverTensorAtDim(
         builder, loc, tid, dim, reduc, isParallel, extraTids, extraDims);
   });
   assert(loop);
@@ -1088,8 +1095,8 @@
   Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
     // Construct the while-loop with a parameter for each
     // index.
-    return env.loopEmitter->enterCoIterationOverTensorsAtDims(
-        builder, env.linalgOp.getLoc(), condTids, condDims, needsUniv, reduc,
+    return env.emitter()->enterCoIterationOverTensorsAtDims(
+        builder, env.op().getLoc(), condTids, condDims, needsUniv, reduc,
         extraTids, extraDims);
   });
   assert(loop);
@@ -1104,10 +1111,10 @@
                           ArrayRef<size_t> extraDims) {
   assert(condTids.size() == condDims.size());
   assert(extraTids.size() == extraDims.size());
-  unsigned idx = env.topSort[at];
+  unsigned idx = env.topSortAt(at);
   if (condTids.size() == 1) {
     bool isOuter = at == 0;
-    bool isInner = at == env.topSort.size() - 1;
+    bool isInner = at == env.topSortSize() - 1;
     return genFor(env, builder, isOuter, isInner, idx, condTids.front(),
                   condDims.front(), extraTids, extraDims);
   }
@@ -1119,9 +1126,9 @@
 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
                             bool needsUniv, BitVector &induction,
                             scf::WhileOp whileOp) {
-  Location loc = env.linalgOp.getLoc();
+  Location loc = env.op().getLoc();
   // Finalize each else branch of all if statements.
-  if (env.isReduc() || env.expValues || env.insChain) {
+  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
     while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
                builder.getInsertionBlock()->getParentOp())) {
       unsigned y = 0;
@@ -1130,13 +1137,13 @@
         yields.push_back(env.getReduc());
         env.updateReduc(ifOp.getResult(y++));
       }
-      if (env.expValues) {
-        yields.push_back(env.expCount);
-        env.expCount = ifOp->getResult(y++);
+      if (env.isExpand()) {
+        yields.push_back(env.getExpandCount());
+        env.updateExpandCount(ifOp->getResult(y++));
       }
-      if (env.insChain) {
-        yields.push_back(env.insChain);
-        env.insChain = ifOp->getResult(y++);
+      if (env.getInsertionChain()) {
+        yields.push_back(env.getInsertionChain());
+        env.updateInsertionChain(ifOp->getResult(y++));
       }
       assert(y == yields.size());
       builder.create<scf::YieldOp>(loc, yields);
@@ -1149,35 +1156,34 @@
 
 /// Generates a single if-statement within a while-loop.
 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, unsigned idx,
                        BitVector &conditions) {
-  Location loc = env.linalgOp.getLoc();
+  Location loc = env.op().getLoc();
   SmallVector<Type> types;
   Value cond;
   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
     if (!conditions[b])
       continue;
-    unsigned tensor = env.merger.tensor(b);
-    assert(idx == env.merger.index(b));
+    unsigned tensor = env.merger().tensor(b);
+    assert(idx == env.merger().index(b));
     Value clause;
-    if (isCompressedDLT(env.dimLevelType(b)) ||
-        isSingletonDLT(env.dimLevelType(b))) {
-      auto dim = *env.merger.getDimNum(tensor, idx);
-      Value op1 = env.loopEmitter->getCoord()[tensor][dim];
+    if (isCompressedDLT(env.dlt(b)) || isSingletonDLT(env.dlt(b))) {
+      auto dim = *env.merger().getDimNum(tensor, idx);
+      Value op1 = env.emitter()->getCoord()[tensor][dim];
       Value op2 = env.getLoopIdxValue(idx);
       clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                              op1, op2);
     } else {
-      assert(isDenseDLT(env.dimLevelType(b)) ||
-             isUndefDLT(env.dimLevelType(b)));
+      assert(isDenseDLT(env.merger().getDimLevelType(b)) ||
+             isUndefDLT(env.merger().getDimLevelType(b)));
      clause = constantI1(builder, loc, true);
    }
    cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
  }
  if (env.isReduc())
    types.push_back(env.getReduc().getType());
-  if (env.expValues)
+  if (env.isExpand())
     types.push_back(builder.getIndexType());
-  if (env.insChain)
-    types.push_back(env.insChain.getType());
+  if (env.getInsertionChain())
+    types.push_back(env.getInsertionChain().getType());
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   return ifOp;
@@ -1192,16 +1198,16 @@
     operands.push_back(env.getReduc());
     env.updateReduc(redInput);
   }
-  if (env.expValues) {
-    operands.push_back(env.expCount);
-    env.expCount = cntInput;
+  if (env.isExpand()) {
+    operands.push_back(env.getExpandCount());
+    env.updateExpandCount(cntInput);
   }
-  if (env.insChain) {
-    operands.push_back(env.insChain);
-    env.insChain = insInput;
+  if (env.getInsertionChain()) {
+    operands.push_back(env.getInsertionChain());
+    env.updateInsertionChain(insInput);
   }
   if (!operands.empty())
-    builder.create<scf::YieldOp>(env.linalgOp.getLoc(), operands);
+    builder.create<scf::YieldOp>(env.op().getLoc(), operands);
   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
 }
 
@@ -1218,17 +1224,17 @@
   // Emit invariants at this loop sequence level.
   genInvariants(env, builder, exp, ldx, /*atStart=*/true);
   // Emit access pattern expansion for sparse tensor output.
-  genExpansion(env, builder, at, /*atStart=*/true);
+  genExpand(env, builder, at, /*atStart=*/true);
   // Emit further intitialization at this loop sequence level.
   unsigned l0 = env.set(lts)[0];
   bool needsUniv = false;
 
   SmallVector<size_t> tids;
   SmallVector<size_t> dims;
-  env.merger.foreachTidDimPairInBits(
+  env.merger().foreachTidDimPairInBits(
       env.lat(l0).bits,
       [&](unsigned b, unsigned tid, Optional<unsigned> dim, DimLevelType dlt) {
-        assert(env.merger.index(b) == idx);
+        assert(env.merger().index(b) == idx);
         if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
           needsUniv = true;
         } else {
@@ -1238,7 +1244,7 @@
         }
       });
 
-  env.loopEmitter->enterNewLoopSeq(builder, env.linalgOp.getLoc(), tids, dims);
+  env.emitter()->enterNewLoopSeq(builder, env.op().getLoc(), tids, dims);
 
   // Maintain the universal index only if it is actually
   // consumed by a subsequent lattice point.
@@ -1246,7 +1252,7 @@
     unsigned lsize = env.set(lts).size();
     for (unsigned i = 1; i < lsize; i++) {
       unsigned li = env.set(lts)[i];
-      if (!env.merger.hasAnySparse(env.lat(li).simple))
+      if (!env.merger().hasAnySparse(env.lat(li).simple))
         return true;
     }
   }
@@ -1257,7 +1263,7 @@
                                              OpBuilder &builder, unsigned tid,
                                              unsigned lvl) {
   // TODO: Handle affine expression on output tensor.
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   assert(tid < op.getNumDpsInputs());
   OpOperand *input = op.getDpsInputOperands()[tid];
   ArrayRef<AffineExpr> affines = op.getMatchingIndexingMap(input).getResults();
@@ -1267,7 +1273,7 @@
     AffineExpr affine = affines[toOrigDim(enc, i)];
     if (isDenseDLT(getDimLevelType(enc, i)) &&
         affine.isa<AffineConstantExpr>())
-      env.loopEmitter->genDenseAffineAddressAtCurLevel(
+      env.emitter()->genDenseAffineAddressAtCurLevel(
           builder, op.getLoc(), input->getOperandNumber(), i, affine);
     else
       return; // break on first non-dense non-constant level
@@ -1281,7 +1287,7 @@
   // starting from the first level as they do not depend on any thing.
   // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
   // levels can be determined before loops.
-  for (unsigned tid = 0, e = env.linalgOp.getNumDpsInputs(); tid < e; tid++)
+  for (unsigned tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
     genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
 }
 
@@ -1295,9 +1301,9 @@
   const BitVector &simple = env.lat(li).simple;
 
   // Converts bits to array + dim pair
-  env.merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
-                                                   Optional<unsigned> dim,
-                                                   DimLevelType dlt) {
+  env.merger().foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
+                                                     Optional<unsigned> dim,
+                                                     DimLevelType dlt) {
     if (simple.test(b)) {
       if (isUndefDLT(dlt)) {
         // An undefined dlt in the lattices, we probably mean to iterate based
@@ -1306,8 +1312,8 @@
        // output tensor).
        // out[i][j] = invariant; or a broadcast
        // out[i][j] = in[i] (j is undef for input)
-        tid = env.merger.getOutTensorID();
-        dim = env.merger.getDimNum(tid, idx);
+        tid = env.merger().getOutTensorID();
+        dim = env.merger().getDimNum(tid, idx);
        // Skips invalid dim (e.g., when this is a zero ranked tensor).
        if (!dim)
          return;
@@ -1320,7 +1326,7 @@
       extraDims.push_back(*dim);
     } else {
       assert(isUndefDLT(dlt));
-      linalg::GenericOp op = env.linalgOp;
+      linalg::GenericOp op = env.op();
       if (tid >= op.getNumDpsInputs())
         // We only handle affine expression on input tensors (for now).
         return;
@@ -1361,12 +1367,12 @@
     }
   });
 
-  if (isDenseDLT(env.dimLevelType(env.merger.getOutTensorID(), idx))) {
+  if (isDenseDLT(env.dlt(env.merger().getOutTensorID(), idx))) {
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized codegen.
-    auto dim = *env.merger.getDimNum(env.merger.getOutTensorID(), idx);
-    extraTids.push_back(env.merger.getOutTensorID());
+    auto dim = *env.merger().getDimNum(env.merger().getOutTensorID(), idx);
+    extraTids.push_back(env.merger().getOutTensorID());
     extraDims.push_back(dim);
   }
 }
@@ -1384,7 +1390,7 @@
   // level.
   SmallVector<size_t> affineTids, affineDims;
   SmallVector<AffineExpr> affines;
-  translateBitsToTidDimPairs(env, li, env.topSort[at], condTids, condDims,
+  translateBitsToTidDimPairs(env, li, env.topSortAt(at), condTids, condDims,
                              extraTids, extraDims, affineTids, affineDims,
                              affines);
 
@@ -1392,8 +1398,8 @@
   Operation *loop = genLoop(env, builder, at, needsUniv, condTids, condDims,
                             extraTids, extraDims);
   for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) {
-    env.loopEmitter->genDenseAffineAddressAtCurLevel(
-        builder, env.linalgOp.getLoc(), tid, dim, exp);
+    env.emitter()->genDenseAffineAddressAtCurLevel(builder, env.op().getLoc(),
+                                                   tid, dim, exp);
   }
 
   // Until now, we have entered every pair in {cond, extra,
@@ -1402,7 +1408,7 @@
   auto allTids = llvm::concat<size_t>(condTids, extraTids, affineTids);
   auto allDims = llvm::concat<size_t>(condDims, extraDims, affineDims);
   for (auto [tid, dim] : llvm::zip(allTids, allDims)) {
-    if (tid != env.merger.getOutTensorID())
+    if (tid != env.merger().getOutTensorID())
       genConstantDenseAddressFromLevel(env, builder, tid, dim + 1);
   }
 
@@ -1420,7 +1426,7 @@
   }
 
   genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
-    env.loopEmitter->exitCurrentLoop(rewriter, env.linalgOp.getLoc(), reduc);
+    env.emitter()->exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
     return std::nullopt;
   });
 
@@ -1431,11 +1437,11 @@
 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
                        unsigned at, unsigned idx, unsigned ldx) {
   assert(env.getLoopIdxValue(idx) == nullptr);
-  env.loopEmitter->exitCurrentLoopSeq();
+  env.emitter()->exitCurrentLoopSeq();
   // Unmark bookkeeping of invariants and loop index.
   genInvariants(env, builder, exp, ldx, /*atStart=*/false);
   // Finalize access pattern expansion for sparse tensor output.
-  genExpansion(env, builder, at, /*atStart=*/false);
+  genExpand(env, builder, at, /*atStart=*/false);
 }
 
 /// Recursively generates code while computing iteration lattices in order
@@ -1444,17 +1450,17 @@
 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
                     unsigned at) {
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
-  if (at == env.topSort.size()) {
-    unsigned ldx = env.topSort[at - 1];
+  if (at == env.topSortSize()) {
+    unsigned ldx = env.topSortAt(at - 1);
     Value rhs = genExp(env, rewriter, exp, ldx);
     genTensorStore(env, rewriter, exp, rhs);
     return;
   }
 
   // Construct iteration lattices for current loop index, with L0 at top.
-  unsigned idx = env.topSort[at];
-  unsigned ldx = at == 0 ? -1u : env.topSort[at - 1];
-  unsigned lts = env.merger.optimizeSet(env.merger.buildLattices(exp, idx));
+  unsigned idx = env.topSortAt(at);
+  unsigned ldx = at == 0 ? -1u : env.topSortAt(at - 1);
+  unsigned lts = env.merger().optimizeSet(env.merger().buildLattices(exp, idx));
 
   // TODO: sort
   // TODO: dedup
@@ -1472,13 +1478,13 @@
     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
     Value redInput = env.getReduc();
-    Value cntInput = env.expCount;
-    Value insInput = env.insChain;
+    Value cntInput = env.getExpandCount();
+    Value insInput = env.getInsertionChain();
     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
     for (unsigned j = 0; j < lsize; j++) {
       unsigned lj = env.set(lts)[j];
       unsigned ej = env.lat(lj).exp;
-      if (li == lj || env.merger.latGT(li, lj)) {
+      if (li == lj || env.merger().latGT(li, lj)) {
         // Recurse into body of each branch.
         if (isWhile) {
           scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple);
@@ -1500,7 +1506,7 @@
 
 /// Converts the result computed by the sparse kernel into the required form.
 static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
-  linalg::GenericOp op = env.linalgOp;
+  linalg::GenericOp op = env.op();
   OpOperand *lhs = op.getDpsInitOperand(0);
   Value tensor = lhs->get();
   Type resType = tensor.getType();
@@ -1508,14 +1514,16 @@
     // The sparse tensor rematerializes from the original sparse tensor's
     // underlying sparse storage format. For an insertion chain, the
     // tensor materializes from the chain with 'hasInserts' enabled.
-    bool hasInserts = env.sparseOut == lhs;
-    if (hasInserts)
-      tensor = env.insChain;
+    bool hasInserts = false;
+    if (Value chain = env.getInsertionChain()) {
+      hasInserts = true;
+      tensor = chain;
+    }
     rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
   } else {
     // To rematerialize an non-annotated tensor, simply load it
     // from the bufferized value.
-    Value val = env.getValBuffer().back(); // value array
+    Value val = env.emitter()->getValBuffer().back(); // value array
     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
   }
 }
@@ -1550,8 +1558,8 @@
       return failure();
 
     // Builds the tensor expression for the Linalg operation in SSA form.
-    Optional<unsigned> optExp = env.merger.buildTensorExpFromLinalg(op);
-    if (!optExp.has_value())
+    Optional<unsigned> optExp = env.merger().buildTensorExpFromLinalg(op);
+    if (!optExp)
       return failure();
     unsigned exp = *optExp;
 
@@ -1562,6 +1570,8 @@
     // to resolve cycles by inserting a conversion.
     bool isAdmissible = false;
     bool hasCycle = true;
+    OpOperand *sparseOut = nullptr;
+    unsigned outerParNest = -1u;
     // An const list of all masks that we used for interation graph
     // computation. Must be ordered from more strict to less strict.
     const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
@@ -1569,7 +1579,7 @@
     for (auto mask : allMask)
       if (computeIterationGraph(env, mask)) {
         hasCycle = false;
-        if (isAdmissibleTensorExp(env, exp)) {
+        if (isAdmissibleTensorExp(env, exp, &sparseOut, &outerParNest)) {
          isAdmissible = true;
          break;
        }
@@ -1589,9 +1599,9 @@
     SparseTensorLoopEmitter lpe(
        tensors,
        StringAttr::get(op.getContext(),
                        linalg::GenericOp::getOperationName()),
-        /*hasOutput=*/true, /*isSparseOut=*/env.sparseOut != nullptr,
-        env.topSort);
-    env.startEmit(&lpe);
+        /*hasOutput=*/true, /*isSparseOut=*/sparseOut != nullptr,
+        env.topSortRef());
+    env.startEmit(sparseOut, outerParNest, &lpe);
 
     // Recursively generates code if admissible.
     genBuffers(env, rewriter);
@@ -1607,7 +1617,7 @@
     // Compute topological sort while leaving out every
     // sparse input tensor in succession until an acylic
     // iteration graph results.
-    for (OpOperand *t : env.linalgOp.getDpsInputOperands()) {
+    for (OpOperand *t : env.op().getDpsInputOperands()) {
       unsigned tensor = t->getOperandNumber();
       Value tval = t->get();
       auto srcEnc = getSparseTensorEncoding(tval.getType());
@@ -1623,14 +1633,14 @@
       auto srcTp = tval.getType().cast<RankedTensorType>();
       auto dstEnc = SparseTensorEncodingAttr::get(
           getContext(), srcEnc.getDimLevelType(),
-          permute(env, env.linalgOp.getMatchingIndexingMap(t)), // new order
+          permute(env, env.op().getMatchingIndexingMap(t)), // new order
           srcEnc.getHigherOrdering(), srcEnc.getPointerBitWidth(),
           srcEnc.getIndexBitWidth());
       auto dstTp = RankedTensorType::get(srcTp.getShape(),
                                          srcTp.getElementType(), dstEnc);
       auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
-      env.linalgOp->setOperand(tensor, convert);
-      rewriter.setInsertionPointAfter(env.linalgOp);
+      env.op()->setOperand(tensor, convert);
+      rewriter.setInsertionPointAfter(env.op());
       rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
       return success();
     }
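Reviewer sketch (illustrative only, not part of the patch): genLoopBoundary in Sparsification.cpp keeps its pack/unpack discipline for loop-carried state, now routed through the new accessors. A self-contained model of that discipline under the same assumptions as the earlier sketch (ints instead of SSA values, every name hypothetical):

// toy_boundary.cpp -- hypothetical toy mirroring how genLoopBoundary packs
// live-out state (reduction value, expansion count, insertion chain) before
// entering a loop and unpacks the updated values afterwards; not MLIR code.
#include <cassert>
#include <functional>
#include <vector>

struct ToyEnv {
  bool reduc = true, expand = true, chain = true; // which states are live
  int redVal = 1, expCount = 2, insChain = 3;
};

static int toyLoopBoundary(ToyEnv &env,
                           const std::function<int(std::vector<int> &)> &cb) {
  std::vector<int> packed;
  if (env.reduc)
    packed.push_back(env.redVal);
  if (env.expand)
    packed.push_back(env.expCount);
  if (env.chain)
    packed.push_back(env.insChain);
  int loop = cb(packed); // the "loop" rewrites the packed values in place
  unsigned i = 0;
  if (env.reduc)
    env.redVal = packed[i++];
  if (env.expand)
    env.expCount = packed[i++];
  if (env.chain)
    env.insChain = packed[i];
  return loop;
}

int main() {
  ToyEnv env;
  toyLoopBoundary(env, [](std::vector<int> &p) {
    for (int &v : p)
      v += 10; // pretend the loop body produced new values
    return 0;
  });
  assert(env.redVal == 11 && env.expCount == 12 && env.insChain == 13);
  return 0;
}

The fixed pack/unpack order (reduction, then expansion count, then chain) is what lets genIf, endIf, and finalizeWhileOp in the patch agree on which scf.if result index carries which value.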