diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -41,7 +41,7 @@ namespace { -// Iteration graph sorting. +/// Iteration graph sorting. enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, @@ -49,37 +49,91 @@ kIncludeAll = 0x3 }; -// Reduction kinds. +/// Reduction kinds. enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom }; -// Code generation. -struct CodeGen { - CodeGen(SparsificationOptions o, MLIRContext *context, ValueRange tensors, - unsigned numTensors, unsigned numLoops, OpOperand *op, unsigned nest, - std::vector &ts) - : options(o), - loopEmitter( - tensors, - StringAttr::get(context, linalg::GenericOp::getOperationName()), - /*hasOutput=*/true, - /*isSparseOut=*/op != nullptr, ts), - sparseOut(op), outerParNest(nest), topSort(ts) { - if (op) - insChain = op->get(); +/// Code generation environment. This structure aggregates a number +/// of data structures needed during code generation. Such an environment +/// simplifies passing around data during sparsification (rather than +/// passing around all the individual components where needed). +// +// TODO: refactor further, move into own file +// +struct CodeGenEnv { + CodeGenEnv(linalg::GenericOp linop, SparsificationOptions opts, + unsigned numTensors, unsigned numLoops, unsigned numFilterLoops) + : linalgOp(linop), options(opts), topSort(), + merger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr), + redExp(-1u), redKind(kNoReduc), redCustom(-1u), sparseOut(nullptr) {} + + // Start emitting. + void startEmit(SparseTensorLoopEmitter *le) { + assert(!loopEmitter && "must only start emitting once"); + loopEmitter = le; + if (sparseOut) { + insChain = sparseOut->get(); + merger.setHasSparseOut(true); + } + } + + // Delegate methods to merger. + TensorExp &exp(unsigned e) { return merger.exp(e); } + LatPoint &lat(unsigned l) { return merger.lat(l); } + SmallVector &set(unsigned s) { return merger.set(s); } + DimLevelType dimLevelType(unsigned t, unsigned i) const { + return merger.getDimLevelType(t, i); + } + DimLevelType dimLevelType(unsigned b) const { + return merger.getDimLevelType(b); + } + bool isFilterLoop(unsigned i) const { return merger.isFilterLoop(i); } + + // Delegate methods to loop emitter. + Value getLoopIV(unsigned i) const { return loopEmitter->getLoopIV(i); } + const std::vector &getValBuffer() const { + return loopEmitter->getValBuffer(); + } + + // Convenience method to slice topsort. + ArrayRef getTopSortSlice(size_t n, size_t m) const { + return ArrayRef(topSort).slice(n, m); + } + + // Convenience method to get current loop stack. + ArrayRef getLoopCurStack() const { + return getTopSortSlice(0, loopEmitter->getCurrentDepth()); + } + + // Convenience method to get the IV of the given loop index. + Value getLoopIdxValue(size_t loopIdx) const { + for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++) + if (topSort[lv] == loopIdx) + return getLoopIV(lv); + llvm_unreachable("invalid loop index"); } + + // TODO: make private + + /// Linalg operation. + linalg::GenericOp linalgOp; /// Sparsification options. SparsificationOptions options; - /// Loop emitter helper class. - SparseTensorLoopEmitter loopEmitter; + /// Topological sort. + std::vector topSort; + /// Merger helper class. + Merger merger; + /// Loop emitter helper class (keep reference in scope!). 
+ /// TODO: move emitter constructor up in time? + SparseTensorLoopEmitter *loopEmitter; /// Current reduction, updated during code generation. When indices of a /// reduction are exhausted, all inner loops can use a scalarized reduction. - unsigned redExp = -1u; + unsigned redExp; Value redVal; - Reduction redKind = kNoReduc; - unsigned redCustom = -1u; - // Sparse tensor as output. Implemented either through direct injective - // insertion in lexicographic index order or through access pattern expansion - // in the innermost loop nest (`expValues` through `expCount`). + Reduction redKind; + unsigned redCustom; + /// Sparse tensor as output. Implemented either through direct injective + /// insertion in lexicographic index order or through access pattern expansion + /// in the innermost loop nest (`expValues` through `expCount`). OpOperand *sparseOut; unsigned outerParNest; Value insChain; // bookkeeping for insertion chain @@ -87,21 +141,6 @@ Value expFilled; Value expAdded; Value expCount; - // Topsort (reference should remain in scope). - std::vector &topSort; - - ArrayRef getLoopCurStack() const { - ArrayRef topSortRef = topSort; - return topSortRef.slice(0, loopEmitter.getCurrentDepth()); - } - - Value getLoopIdxValue(size_t loopIdx) const { - for (unsigned lv = 0; lv < topSort.size(); lv++) - if (topSort[lv] == loopIdx) - return loopEmitter.getLoopIV(lv); - - llvm_unreachable("invalid loop index"); - } }; /// A helper class that visits an affine expression and tries to find an @@ -133,6 +172,7 @@ /// The mapping between dim=>iterator type. SmallVector iterTypes; }; + } // namespace //===----------------------------------------------------------------------===// @@ -172,17 +212,17 @@ } /// Determines if affine expression is invariant. -static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, - unsigned ldx, bool &atLevel) { - return isInvariantAffine(a, codegen.getLoopCurStack(), ldx, atLevel); +static bool isInvariantAffine(CodeGenEnv &env, AffineExpr a, unsigned ldx, + bool &atLevel) { + return isInvariantAffine(a, env.getLoopCurStack(), ldx, atLevel); } /// Helper method to construct a permuted dimension ordering /// that adheres to the given topological sort. -static AffineMap permute(const Merger &merger, MLIRContext *context, - AffineMap m, ArrayRef topSort) { - assert(m.getNumDims() + merger.getNumFilterLoops() == topSort.size() && - "TopoSort/AffineMap size mismatch"); +static AffineMap permute(CodeGenEnv &env, AffineMap m) { + assert(m.getNumDims() + env.merger.getNumFilterLoops() == + env.topSort.size() && + "size mismatch"); // Construct the inverse of `m`; to avoid the asymptotic complexity // of calling `m.getPermutedPosition` repeatedly. SmallVector perm; @@ -191,13 +231,14 @@ unsigned loopDepth = 1; // Construct the permutation. - while (worklist.any() && loopDepth <= topSort.size()) { + while (worklist.any() && loopDepth <= env.topSort.size()) { unsigned preSize = perm.size(); for (auto dim : worklist.set_bits()) { bool atLevel = false; if (m.getResult(dim).isa() || - (isInvariantAffine(m.getResult(dim), topSort.slice(0, loopDepth), - topSort[loopDepth - 1], atLevel) && + (isInvariantAffine(m.getResult(dim), + env.getTopSortSlice(0, loopDepth), + env.topSort[loopDepth - 1], atLevel) && atLevel)) { // If the matching affine is constant expression or just become // invariant. 
We can visit the dimension now without breaking the @@ -215,7 +256,7 @@ } assert(perm.size() == numResults); - return AffineMap::getPermutationMap(perm, context); + return AffineMap::getPermutationMap(perm, env.linalgOp.getContext()); } /// Helper method to inspect affine expressions. Rejects cases where the @@ -305,24 +346,24 @@ /// Returns true if the sparse annotations and affine subscript /// expressions of all tensors are admissible. Returns false if /// no annotations are found or inadmissible constructs occur. -static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { +static bool findSparseAnnotations(CodeGenEnv &env) { bool annotated = false; - unsigned filterLdx = merger.getFilterLoopStartingIdx(); - for (OpOperand &t : op->getOpOperands()) { - auto map = op.getMatchingIndexingMap(&t); + unsigned filterLdx = env.merger.getFilterLoopStartingIdx(); + for (OpOperand &t : env.linalgOp->getOpOperands()) { + auto map = env.linalgOp.getMatchingIndexingMap(&t); auto enc = getSparseTensorEncoding(t.get().getType()); if (enc) annotated = true; - assert(map.getNumResults() == op.getRank(&t)); - + assert(map.getNumResults() == env.linalgOp.getRank(&t)); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { unsigned tensor = t.getOperandNumber(); AffineExpr a = map.getResult(toOrigDim(enc, d)); - if (!findAffine(merger, tensor, d, a, getDimLevelType(enc, d), filterLdx)) + if (!findAffine(env.merger, tensor, d, a, getDimLevelType(enc, d), + filterLdx)) return false; // inadmissible affine expression } } - assert(filterLdx == merger.getNumLoops()); + assert(filterLdx == env.merger.getNumLoops()); return annotated; } @@ -330,9 +371,8 @@ /// as we use adj matrix for the graph. /// The sorted result will put the first Reduction iterator to the /// latest possible index. -static bool topSortOptimal(unsigned n, +static bool topSortOptimal(CodeGenEnv &env, unsigned n, ArrayRef iteratorTypes, - const Merger &merger, std::vector &topSort, std::vector &inDegree, std::vector> &adjM) { std::vector redIt; // reduce iterator with 0 degree @@ -340,7 +380,7 @@ std::vector filterIt; // filter loop with 0 degree for (unsigned i = 0; i < n; i++) { if (inDegree[i] == 0) { - if (merger.isFilterLoop(i)) + if (env.isFilterLoop(i)) filterIt.push_back(i); else if (linalg::isReductionIterator(iteratorTypes[i])) redIt.push_back(i); @@ -371,12 +411,12 @@ // O(X) computation => O(NK+NMX) time complexity auto &it = !filterIt.empty() ? filterIt : (!parIt.empty() ? parIt : redIt); auto src = it.back(); - topSort.push_back(src); + env.topSort.push_back(src); it.pop_back(); // Update in-degree, and push 0-degree node into worklist. for (unsigned dst = 0; dst < n; dst++) { if (adjM[src][dst] && --inDegree[dst] == 0) { - if (merger.isFilterLoop(dst)) + if (env.isFilterLoop(dst)) filterIt.push_back(dst); else if (linalg::isReductionIterator(iteratorTypes[dst])) redIt.push_back(dst); @@ -385,7 +425,7 @@ } } } - return topSort.size() == n; + return env.topSort.size() == n; } /// Helper method to add all constraints from the indices in one affine @@ -477,21 +517,21 @@ /// along fixed dimensions. Even for dense storage formats, however, the /// natural index order yields innermost unit-stride access with better /// spatial locality. 
-static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, - std::vector &topSort, unsigned mask, +static bool computeIterationGraph(CodeGenEnv &env, unsigned mask, OpOperand *skip = nullptr) { // Set up an n x n from/to adjacency matrix of the iteration graph // for the implicit loop indices i_0 .. i_n-1. - unsigned n = merger.getNumLoops(); + unsigned n = env.merger.getNumLoops(); std::vector> adjM(n, std::vector(n, false)); std::vector inDegree(n, 0); // in-degree of each node. - auto iteratorTypes = op.getIteratorTypesArray(); + auto iteratorTypes = env.linalgOp.getIteratorTypesArray(); // Iterate over the indexing maps of every tensor in the tensor expression. - for (OpOperand &t : op->getOpOperands()) { + for (OpOperand &t : env.linalgOp->getOpOperands()) { // Get map and encoding. - auto map = op.getMatchingIndexingMap(&t); + auto map = env.linalgOp.getMatchingIndexingMap(&t); auto enc = getSparseTensorEncoding(t.get().getType()); - assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(op) == n); + assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.linalgOp) == + n); // Skip dense tensor constraints when not requested. if (!(mask & SortMask::kIncludeDense) && !enc) continue; @@ -501,11 +541,11 @@ // on the loop indices if no explicit dimension ordering is given. for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { AffineExpr ta = map.getResult(toOrigDim(enc, d)); - Optional tldx = merger.getLoopIdx(t.getOperandNumber(), d); + Optional tldx = env.merger.getLoopIdx(t.getOperandNumber(), d); // Filter loops should be constructed after all the dependent loops, // i.e., d0 + d1 < filter_loop(d0 + d1) - if (tldx && merger.isFilterLoop(*tldx)) { + if (tldx && env.isFilterLoop(*tldx)) { assert(!ta.isa() && !isDenseDLT(getDimLevelType(enc, d))); addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, @@ -525,7 +565,7 @@ if (d > 0) { AffineExpr fa = map.getResult(toOrigDim(enc, d - 1)); Optional fldx = - merger.getLoopIdx(t.getOperandNumber(), d - 1); + env.merger.getLoopIdx(t.getOperandNumber(), d - 1); // Applying order constraints on every pair of dimExpr between two // compound affine expressions can sometime too strict: @@ -533,7 +573,7 @@ // It is totally fine to have loop sequence d0->d2->d1->d3 instead of // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3. if (!(mask & SortMask::kIncludeDense)) - tryLoosenAffineDenseConstraints(op, fldx, fa, tldx, ta); + tryLoosenAffineDenseConstraints(env.linalgOp, fldx, fa, tldx, ta); // (d0 + d1) < (d2 + d3), or // filter_loop_d-1 < (d2 + d3), or @@ -548,24 +588,24 @@ if (mask & SortMask::kIncludeUndef) { unsigned tensor = t.getOperandNumber(); for (unsigned i = 0; i < n; i++) - if (isCompressedDLT(merger.getDimLevelType(tensor, i)) || - isSingletonDLT(merger.getDimLevelType(tensor, i))) { + if (isCompressedDLT(env.dimLevelType(tensor, i)) || + isSingletonDLT(env.dimLevelType(tensor, i))) { for (unsigned j = 0; j < n; j++) - if (isUndefDLT(merger.getDimLevelType(tensor, j))) { + if (isUndefDLT(env.dimLevelType(tensor, j))) { adjM[i][j] = true; inDegree[j]++; } } else { - assert(isDenseDLT(merger.getDimLevelType(tensor, i)) || - isUndefDLT(merger.getDimLevelType(tensor, i))); + assert(isDenseDLT(env.dimLevelType(tensor, i)) || + isUndefDLT(env.dimLevelType(tensor, i))); } } } // Topologically sort the iteration graph to determine loop order. // Report failure for a cyclic iteration graph. 
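// For illustration: a self-contained sketch (plain C++, hypothetical names, not
// the MLIR types used above) of the worklist strategy behind topSortOptimal:
// Kahn's algorithm over the adjacency matrix, always draining filter loops
// first, then parallel loops, and reduction loops only as a last resort, so
// that the first reduction is scheduled as late as possible.
#include <vector>

enum class LoopKind { Filter, Parallel, Reduction };

// Returns true and fills `order` if the iteration graph is acyclic.
static bool topoSortSketch(const std::vector<std::vector<bool>> &adj,
                           std::vector<unsigned> inDegree,
                           const std::vector<LoopKind> &kind,
                           std::vector<unsigned> &order) {
  const unsigned n = static_cast<unsigned>(adj.size());
  std::vector<unsigned> filterIt, parIt, redIt; // 0-degree worklists
  auto enqueue = [&](unsigned i) {
    if (kind[i] == LoopKind::Filter)
      filterIt.push_back(i);
    else if (kind[i] == LoopKind::Reduction)
      redIt.push_back(i);
    else
      parIt.push_back(i);
  };
  for (unsigned i = 0; i < n; i++)
    if (inDegree[i] == 0)
      enqueue(i);
  while (!filterIt.empty() || !parIt.empty() || !redIt.empty()) {
    // Pick from the highest-priority non-empty worklist.
    auto &it = !filterIt.empty() ? filterIt : (!parIt.empty() ? parIt : redIt);
    unsigned src = it.back();
    it.pop_back();
    order.push_back(src);
    for (unsigned dst = 0; dst < n; dst++)
      if (adj[src][dst] && --inDegree[dst] == 0)
        enqueue(dst);
  }
  return order.size() == n; // anything short of n means a cycle was found
}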
- topSort.clear(); - topSort.reserve(n); - return topSortOptimal(n, iteratorTypes, merger, topSort, inDegree, adjM); + env.topSort.clear(); + env.topSort.reserve(n); + return topSortOptimal(env, n, iteratorTypes, inDegree, adjM); } /// Returns true if tensor materializes uninitialized into the computation. @@ -574,29 +614,25 @@ val.getDefiningOp(); } -/// Returns true when the tensor expression is admissible for codegen. +/// Returns true when the tensor expression is admissible for code generation. /// Since all sparse input tensors are admissible, we just need to check /// whether the out tensor in the tensor expression codegen is admissible. /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective /// nesting depth when a "truly dynamic" sparse tensor output occurs. -static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op, - std::vector &topSort, unsigned exp, - OpOperand **sparseOut, - unsigned &outerParNest) { +static bool isAdmissibleTensorExp(CodeGenEnv &env, unsigned exp) { // We reject any expression that makes a reduction from `-outTensor`, as those - // expression create dependency between the current iteration (i) and the - // previous iteration (i-1). It would then require iterating over the whole - // coordinate space, which prevent us from exploiting sparsity for faster - // code. - for (utils::IteratorType it : op.getIteratorTypesArray()) { + // expressions create a dependency between the current iteration (i) and the + // previous iteration (i-1). It would require iterating over the whole + // coordinate space, which prevents exploiting sparsity for faster code. + for (utils::IteratorType it : env.linalgOp.getIteratorTypesArray()) { if (it == utils::IteratorType::reduction) { - if (merger.hasNegateOnOut(exp)) + if (env.merger.hasNegateOnOut(exp)) return false; break; } } - OpOperand *lhs = op.getDpsInitOperand(0); + OpOperand *lhs = env.linalgOp.getDpsInitOperand(0); unsigned tensor = lhs->getOperandNumber(); auto enc = getSparseTensorEncoding(lhs->get().getType()); // A non-annotated output tensor is assumed dense, and becomes a random @@ -606,40 +642,41 @@ // An all-dense annotated "sparse" output tensor becomes a linearized random // access 1-dim memref. Also admissible since insertions cannot occur. bool allDense = true; - unsigned numLoops = merger.getNumLoops(); // numNativeLoops + numFilterLoops - for (unsigned i = 0; i < merger.getNumLoops(); i++) - if (isCompressedDLT(merger.getDimLevelType(tensor, i)) || - isSingletonDLT(merger.getDimLevelType(tensor, i))) { + unsigned numLoops = + env.merger.getNumLoops(); // numNativeLoops + numFilterLoops + for (unsigned i = 0; i < env.merger.getNumLoops(); i++) + if (isCompressedDLT(env.dimLevelType(tensor, i)) || + isSingletonDLT(env.dimLevelType(tensor, i))) { allDense = false; break; } else { - assert(isDenseDLT(merger.getDimLevelType(tensor, i)) || - isUndefDLT(merger.getDimLevelType(tensor, i))); + assert(isDenseDLT(env.dimLevelType(tensor, i)) || + isUndefDLT(env.dimLevelType(tensor, i))); } if (allDense) return true; // TODO: support compound affine expression on sparse output. - if (getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(lhs), + if (getNumCompoundAffineOnSparseDims(env.linalgOp.getMatchingIndexingMap(lhs), lhs->get()) != 0) return false; // A tensor expression with a sparse output tensor that changes its values // but not its nonzero structure, an operation called "simply dynamic" in - // [Bik96,Ch9], is also admissible without special codegen. 
- if (merger.isSingleCondition(tensor, exp)) + // [Bik96,Ch9], is also admissible without special code generation. + if (env.merger.isSingleCondition(tensor, exp)) return true; // Accept "truly dynamic" if the output tensor materializes uninitialized // into the computation and insertions occur in lexicographic index order. if (isMaterializing(lhs->get())) { - auto iteratorTypes = op.getIteratorTypesArray(); + auto iteratorTypes = env.linalgOp.getIteratorTypesArray(); unsigned nest = 0; for (unsigned i = 0; i < numLoops; i++) { - if (!merger.isFilterLoop(topSort[i])) { + if (!env.isFilterLoop(env.topSort[i])) { // We only count non-filter loops as filter loops should be considered // as a special type of parallel loops. - if (linalg::isReductionIterator(iteratorTypes[topSort[i]])) + if (linalg::isReductionIterator(iteratorTypes[env.topSort[i]])) break; // terminate at first reduction nest++; } @@ -647,9 +684,9 @@ // Determine admissible dynamic insertion situations: // (1) fully injective, since there are no reductions, // (2) admissible 1-d expansion in innermost dimension. - if (nest >= op.getRank(lhs) - 1) { - *sparseOut = lhs; - outerParNest = nest; + if (nest >= env.linalgOp.getRank(lhs) - 1) { + env.sparseOut = lhs; + env.outerParNest = nest; return true; } } @@ -688,9 +725,9 @@ } /// Updates scalarized reduction value. -static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) { - assert(codegen.redKind != kNoReduc); - codegen.redVal = merger.exp(codegen.redExp).val = reduc; +static void updateReduc(CodeGenEnv &env, Value reduc) { + assert(env.redKind != kNoReduc); + env.redVal = env.exp(env.redExp).val = reduc; } /// Extracts identity from custom reduce. @@ -705,38 +742,38 @@ /// Generates loop boundary statements (entering/exiting loops). The function /// passes and updates the reduction value. static Optional genLoopBoundary( - CodeGen &codegen, Merger &merger, + CodeGenEnv &env, function_ref(MutableArrayRef reduc)> callback) { SmallVector reduc; - if (codegen.redVal) - reduc.push_back(codegen.redVal); - if (codegen.expValues) - reduc.push_back(codegen.expCount); - if (codegen.insChain) - reduc.push_back(codegen.insChain); + if (env.redVal) + reduc.push_back(env.redVal); + if (env.expValues) + reduc.push_back(env.expCount); + if (env.insChain) + reduc.push_back(env.insChain); auto r = callback(reduc); // Callback should do in-place update on reduction value vector. unsigned i = 0; - if (codegen.redVal) - updateReduc(merger, codegen, reduc[i++]); - if (codegen.expValues) - codegen.expCount = reduc[i++]; - if (codegen.insChain) - codegen.insChain = reduc[i]; + if (env.redVal) + updateReduc(env, reduc[i++]); + if (env.expValues) + env.expCount = reduc[i++]; + if (env.insChain) + env.insChain = reduc[i]; return r; } /// Local bufferization of all dense and sparse data structures. -static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op) { +static void genBuffers(CodeGenEnv &env, OpBuilder &builder) { + linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); assert(op.getNumOperands() == op.getNumDpsInputs() + 1); - codegen.loopEmitter.initializeLoopEmit( + env.loopEmitter->initializeLoopEmit( builder, loc, /// Generates buffer for the output tensor. /// Note that all sparse kernels assume that when all elements are written @@ -749,10 +786,9 @@ Value tensor) -> Value { // Must not be a sparse tensor. assert(!getSparseTensorEncoding(tensor.getType())); + // Two output tensor references should point to the same object. 
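// For illustration: a scalar analogue (plain C++, hypothetical types) of the
// genLoopBoundary helper above. Because the generated scf loops carry the
// scalarized reduction, the expansion counter, and the insertion chain as
// explicit loop arguments, every loop entry/exit gathers whichever of these
// values are live, lets a callback build the loop and update them in place,
// and then scatters the results back into the environment.
#include <functional>
#include <optional>
#include <vector>

struct LoopCarriedSketch {
  int *redVal = nullptr;   // scalarized reduction, if live
  int *expCount = nullptr; // expanded-access counter, if live
  int *insChain = nullptr; // insertion chain, if live
};

static std::optional<int>
loopBoundarySketch(LoopCarriedSketch &env,
                   const std::function<std::optional<int>(std::vector<int> &)> &cb) {
  std::vector<int> reduc;
  if (env.redVal)
    reduc.push_back(*env.redVal);
  if (env.expCount)
    reduc.push_back(*env.expCount);
  if (env.insChain)
    reduc.push_back(*env.insChain);
  std::optional<int> r = cb(reduc); // callback updates `reduc` in place
  unsigned i = 0;
  if (env.redVal)
    *env.redVal = reduc[i++];
  if (env.expCount)
    *env.expCount = reduc[i++];
  if (env.insChain)
    *env.insChain = reduc[i];
  return r;
}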
OpOperand *lhs = op.getDpsInitOperand(0); - // Two output tensors references should pointed to the same object. assert(lhs->get() == tensor); - bool isInit = op.isInitTensor(lhs); // An output tensor can simply materialize from the buffer of the tensor // that appears in the outs() clause. For updates, this has the // advantage that only the nonzero value are involved in the @@ -761,6 +797,7 @@ // may negatively impact running complexity (viz. O(n^2 + nnz) vs. // O(nnz) for matrices). // TODO: use better analysis to avoid zeroing out the buffer? + bool isInit = op.isInitTensor(lhs); Value init = memref; if (!isInit) { Value zero = constantZero(builder, loc, @@ -773,83 +810,82 @@ } /// Generates index for load/store on sparse tensor. -static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) { - auto map = op.getMatchingIndexingMap(t); +static Value genIndex(CodeGenEnv &env, OpOperand *t) { + auto map = env.linalgOp.getMatchingIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1)); assert(a.getKind() == AffineExprKind::DimId); unsigned idx = a.cast().getPosition(); - return codegen.getLoopIdxValue(idx); + return env.getLoopIdxValue(idx); } /// Generates subscript for load/store on a dense or sparse tensor. -static Value genSubscript(CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, OpOperand *t, +static Value genSubscript(CodeGenEnv &env, OpBuilder &builder, OpOperand *t, SmallVectorImpl &args) { + linalg::GenericOp op = env.linalgOp; unsigned tensor = t->getOperandNumber(); auto map = op.getMatchingIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); unsigned rank = map.getNumResults(); if (enc) { - Value pidx = codegen.loopEmitter.getPidxs()[tensor].back(); + Value pidx = env.loopEmitter->getPidxs()[tensor].back(); assert(pidx); args.push_back(pidx); // position index } else { for (unsigned d = 0; d < rank; d++) { AffineExpr a = map.getResult(d); - args.push_back(codegen.loopEmitter.genAffine(builder, a, op.getLoc())); + args.push_back(env.loopEmitter->genAffine(builder, a, op.getLoc())); } } - return codegen.loopEmitter.getValBuffer()[tensor]; + return env.getValBuffer()[tensor]; } /// Generates insertion code to implement dynamic tensor load. -static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, OpOperand *t) { +static Value genInsertionLoad(CodeGenEnv &env, OpBuilder &builder, + OpOperand *t) { + linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); // Direct lexicographic index order, tensor loads as zero. - if (!codegen.expValues) { + if (!env.expValues) { Type tp = getElementTypeOrSelf(t->get().getType()); return constantZero(builder, loc, tp); } // Load from expanded access pattern. - Value index = genIndex(codegen, op, t); - return builder.create(loc, codegen.expValues, index); + Value index = genIndex(env, t); + return builder.create(loc, env.expValues, index); } /// Generates insertion code to implement dynamic tensor load for reduction. -static Value genInsertionLoadReduce(Merger &merger, CodeGen &codegen, - OpBuilder &builder, linalg::GenericOp op, +static Value genInsertionLoadReduce(CodeGenEnv &env, OpBuilder &builder, OpOperand *t) { + linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); - Value identity = getCustomRedId(merger.exp(codegen.redCustom).op); + Value identity = getCustomRedId(env.exp(env.redCustom).op); // Direct lexicographic index order, tensor loads as identity. 
- if (!codegen.expValues) { + if (!env.expValues) { return identity; } // Load from expanded access pattern if filled, identity otherwise. - Value index = genIndex(codegen, op, t); - Value isFilled = - builder.create(loc, codegen.expFilled, index); - Value valAtIndex = - builder.create(loc, codegen.expValues, index); + Value index = genIndex(env, t); + Value isFilled = builder.create(loc, env.expFilled, index); + Value valAtIndex = builder.create(loc, env.expValues, index); return builder.create(loc, isFilled, valAtIndex, identity); } /// Generates insertion code to implement dynamic tensor store. -static void genInsertionStore(CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, OpOperand *t, Value rhs) { +static void genInsertionStore(CodeGenEnv &env, OpBuilder &builder, OpOperand *t, + Value rhs) { + linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); // Direct insertion in lexicographic index order. - if (!codegen.expValues) { + if (!env.expValues) { unsigned rank = op.getRank(t); SmallVector indices; for (unsigned i = 0; i < rank; i++) { - assert(codegen.loopEmitter.getLoopIV(i)); - indices.push_back(codegen.loopEmitter.getLoopIV(i)); + assert(env.getLoopIV(i)); + indices.push_back(env.getLoopIV(i)); } - codegen.insChain = - builder.create(loc, rhs, codegen.insChain, indices); + env.insChain = builder.create(loc, rhs, env.insChain, indices); return; } // Generates insertion code along expanded access pattern. @@ -858,110 +894,108 @@ // expAdded[inserts++] = i // endif // values[i] = rhs - Value index = genIndex(codegen, op, t); + Value index = genIndex(env, t); Value fval = constantI1(builder, loc, false); Value tval = constantI1(builder, loc, true); // If statement. - Value filled = builder.create(loc, codegen.expFilled, index); + Value filled = builder.create(loc, env.expFilled, index); Value cond = builder.create(loc, arith::CmpIPredicate::eq, filled, fval); scf::IfOp ifOp = builder.create(loc, builder.getIndexType(), cond, /*else=*/true); // True branch. builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - builder.create(loc, tval, codegen.expFilled, index); - builder.create(loc, index, codegen.expAdded, - codegen.expCount); + builder.create(loc, tval, env.expFilled, index); + builder.create(loc, index, env.expAdded, env.expCount); Value one = constantIndex(builder, loc, 1); - Value add = builder.create(loc, codegen.expCount, one); + Value add = builder.create(loc, env.expCount, one); builder.create(loc, add); // False branch. builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - builder.create(loc, codegen.expCount); + builder.create(loc, env.expCount); builder.setInsertionPointAfter(ifOp); // Value assignment. - codegen.expCount = ifOp.getResult(0); - builder.create(loc, rhs, codegen.expValues, index); + env.expCount = ifOp.getResult(0); + builder.create(loc, rhs, env.expValues, index); } /// Generates a load on a dense or sparse tensor. -static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, unsigned exp) { +static Value genTensorLoad(CodeGenEnv &env, OpBuilder &builder, unsigned exp) { // Test if the load was hoisted to a higher loop nest. - Value val = merger.exp(exp).val; + Value val = env.exp(exp).val; if (val) return val; // Load during insertion. 
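// For illustration: a self-contained model (plain C++, hypothetical names) of
// the expanded access pattern driven by expValues/expFilled/expAdded/expCount:
// a dense scratch row that remembers which positions were touched, so only
// those positions need to be compressed back into sparse storage afterwards
// (the compress step in the dialect additionally keeps insertions in
// lexicographic order; that detail is omitted here).
#include <cstddef>
#include <utility>
#include <vector>

struct WorkspaceSketch {
  std::vector<double> values; // expValues: dense scratch values
  std::vector<bool> filled;   // expFilled: is position i currently set?
  std::vector<size_t> added;  // expAdded: positions set so far
  size_t count = 0;           // expCount: number of positions set

  explicit WorkspaceSketch(size_t n)
      : values(n, 0.0), filled(n, false), added(n) {}

  // Mirrors genInsertionStore: only the first touch of `i` is recorded.
  void insert(size_t i, double rhs) {
    if (!filled[i]) {
      filled[i] = true;
      added[count++] = i;
    }
    values[i] = rhs;
  }

  // Mirrors the compress in genExpansion further below: emit touched entries,
  // then reset the workspace for the next iteration of the enclosing loops.
  std::vector<std::pair<size_t, double>> compress() {
    std::vector<std::pair<size_t, double>> out;
    for (size_t k = 0; k < count; k++) {
      size_t i = added[k];
      out.emplace_back(i, values[i]);
      filled[i] = false;
      values[i] = 0.0;
    }
    count = 0;
    return out;
  }
};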
- OpOperand &t = op->getOpOperand(merger.exp(exp).tensor); - if (&t == codegen.sparseOut) { - if (codegen.redCustom != -1u) - return genInsertionLoadReduce(merger, codegen, builder, op, &t); - return genInsertionLoad(codegen, builder, op, &t); + linalg::GenericOp op = env.linalgOp; + OpOperand &t = op->getOpOperand(env.exp(exp).tensor); + if (&t == env.sparseOut) { + if (env.redCustom != -1u) + return genInsertionLoadReduce(env, builder, &t); + return genInsertionLoad(env, builder, &t); } // Actual load. SmallVector args; - Value ptr = genSubscript(codegen, builder, op, &t, args); + Value ptr = genSubscript(env, builder, &t, args); return builder.create(op.getLoc(), ptr, args); } /// Generates a store on a dense or sparse tensor. -static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, unsigned exp, Value rhs) { +static void genTensorStore(CodeGenEnv &env, OpBuilder &builder, unsigned exp, + Value rhs) { + linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); // Test if this is a scalarized reduction. - if (codegen.redVal) { - updateReduc(merger, codegen, rhs); + if (env.redVal) { + updateReduc(env, rhs); return; } // Store during insertion. OpOperand *t = op.getDpsInitOperand(0); - if (t == codegen.sparseOut) { + if (t == env.sparseOut) { if (!rhs) { // Only unary and binary are allowed to return uninitialized rhs // to indicate missing output. - assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary); - } else if (merger.exp(exp).kind == kSelect) { + assert(env.exp(exp).kind == kUnary || env.exp(exp).kind == kBinary); + } else if (env.exp(exp).kind == kSelect) { // Select operation insertion. - Value insChain = codegen.insChain; + Value insChain = env.insChain; assert(insChain); scf::IfOp ifOp = builder.create(loc, insChain.getType(), rhs, /*else=*/true); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); // Existing value was preserved to be used here. - assert(merger.exp(exp).val); - Value v0 = merger.exp(exp).val; - genInsertionStore(codegen, builder, op, t, v0); - merger.exp(exp).val = Value(); + assert(env.exp(exp).val); + Value v0 = env.exp(exp).val; + genInsertionStore(env, builder, t, v0); + env.exp(exp).val = Value(); // Yield modified insertion chain along true branch. - builder.create(op.getLoc(), codegen.insChain); + builder.create(op.getLoc(), env.insChain); // Yield original insertion chain along false branch. builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); builder.create(loc, insChain); // Done with if statement. - codegen.insChain = ifOp->getResult(0); + env.insChain = ifOp->getResult(0); builder.setInsertionPointAfter(ifOp); } else { - genInsertionStore(codegen, builder, op, t, rhs); + genInsertionStore(env, builder, t, rhs); } return; } // Actual store. SmallVector args; - Value ptr = genSubscript(codegen, builder, op, t, args); + Value ptr = genSubscript(env, builder, t, args); builder.create(loc, rhs, ptr, args); } /// Generates an invariant value. -inline static Value genInvariantValue(Merger &merger, CodeGen &codegen, - OpBuilder &builder, unsigned exp) { - return merger.exp(exp).val; +inline static Value genInvariantValue(CodeGenEnv &env, unsigned exp) { + return env.exp(exp).val; } /// Generates an index value. 
-inline static Value genIndexValue(CodeGen &codegen, OpBuilder &builder, - unsigned idx) { - return codegen.getLoopIdxValue(idx); +inline static Value genIndexValue(CodeGenEnv &env, unsigned idx) { + return env.getLoopIdxValue(idx); } /// Semi-ring branches are simply inlined by the sparse compiler. Prior @@ -969,86 +1003,84 @@ /// branch or otherwise invariantly defined outside the loop nest, with the /// exception of index computations, which need to be relinked to actual /// inlined cloned code. -static Value relinkBranch(CodeGen &codegen, RewriterBase &rewriter, - Block *block, Value e, unsigned ldx) { +static Value relinkBranch(CodeGenEnv &env, RewriterBase &rewriter, Block *block, + Value e, unsigned ldx) { if (Operation *def = e.getDefiningOp()) { if (auto indexOp = dyn_cast(def)) - return genIndexValue(codegen, rewriter, indexOp.getDim()); + return genIndexValue(env, indexOp.getDim()); if (def->getBlock() == block) { for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) def->setOperand( - i, relinkBranch(codegen, rewriter, block, def->getOperand(i), ldx)); + i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx)); } } return e; } /// Recursively generates tensor expression. -static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, - linalg::GenericOp op, unsigned exp, unsigned ldx) { +static Value genExp(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp, + unsigned ldx) { + linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); + if (exp == -1u) return Value(); - if (merger.exp(exp).kind == Kind::kTensor) - return genTensorLoad(merger, codegen, rewriter, op, exp); - if (merger.exp(exp).kind == Kind::kInvariant) - return genInvariantValue(merger, codegen, rewriter, exp); - if (merger.exp(exp).kind == Kind::kIndex) - return genIndexValue(codegen, rewriter, merger.exp(exp).index); - - if (merger.exp(exp).kind == Kind::kReduce) { - // Make custom reduction identity accessible for expanded access pattern. - assert(codegen.redCustom == -1u); - codegen.redCustom = exp; + if (env.exp(exp).kind == Kind::kTensor) + return genTensorLoad(env, rewriter, exp); + if (env.exp(exp).kind == Kind::kInvariant) + return genInvariantValue(env, exp); + if (env.exp(exp).kind == Kind::kIndex) + return genIndexValue(env, env.exp(exp).index); + + // Make custom reduction identity accessible for expanded access pattern. + if (env.exp(exp).kind == Kind::kReduce) { + assert(env.redCustom == -1u); + env.redCustom = exp; } - Value v0 = - genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx); - Value v1 = - genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx); - Value ee = merger.buildExp(rewriter, loc, exp, v0, v1); - if (ee && (merger.exp(exp).kind == Kind::kUnary || - merger.exp(exp).kind == Kind::kBinary || - merger.exp(exp).kind == Kind::kBinaryBranch || - merger.exp(exp).kind == Kind::kReduce || - merger.exp(exp).kind == Kind::kSelect)) - ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx); - - if (merger.exp(exp).kind == kSelect) { - assert(!merger.exp(exp).val); - merger.exp(exp).val = v0; // Preserve value for later use. 
- } - - if (merger.exp(exp).kind == Kind::kReduce) { - assert(codegen.redCustom != -1u); - codegen.redCustom = -1u; + Value v0 = genExp(env, rewriter, env.exp(exp).children.e0, ldx); + Value v1 = genExp(env, rewriter, env.exp(exp).children.e1, ldx); + Value ee = env.merger.buildExp(rewriter, loc, exp, v0, v1); + if (ee && (env.exp(exp).kind == Kind::kUnary || + env.exp(exp).kind == Kind::kBinary || + env.exp(exp).kind == Kind::kBinaryBranch || + env.exp(exp).kind == Kind::kReduce || + env.exp(exp).kind == Kind::kSelect)) + ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx); + + if (env.exp(exp).kind == kSelect) { + assert(!env.exp(exp).val); + env.exp(exp).val = v0; // Preserve value for later use. + } else if (env.exp(exp).kind == Kind::kReduce) { + assert(env.redCustom != -1u); + env.redCustom = -1u; } return ee; } /// Hoists loop invariant tensor loads for which indices have been exhausted. -static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, unsigned exp, unsigned ldx, - bool atStart, unsigned last = -1u) { +static void genInvariants(CodeGenEnv &env, OpBuilder &builder, unsigned exp, + unsigned ldx, bool atStart, unsigned last = -1u) { if (exp == -1u) return; - if (merger.exp(exp).kind == Kind::kTensor) { + if (env.exp(exp).kind == Kind::kTensor) { // Inspect tensor indices. bool atLevel = ldx == -1u; - OpOperand &t = op->getOpOperand(merger.exp(exp).tensor); + linalg::GenericOp op = env.linalgOp; + OpOperand &t = op->getOpOperand(env.exp(exp).tensor); auto map = op.getMatchingIndexingMap(&t); auto enc = getSparseTensorEncoding(t.get().getType()); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { AffineExpr a = map.getResult(toOrigDim(enc, d)); - Optional sldx = merger.getLoopIdx(t.getOperandNumber(), d); - if (sldx && merger.isFilterLoop(*sldx)) { - if (!codegen.getLoopIdxValue(*sldx)) + Optional sldx = env.merger.getLoopIdx(t.getOperandNumber(), d); + if (sldx && env.isFilterLoop(*sldx)) { + if (!env.getLoopIdxValue(*sldx)) // The filter loops has not been constructed. return; if (*sldx == ldx) atLevel = true; - } else if (!isInvariantAffine(codegen, a, ldx, atLevel)) + } else if (!isInvariantAffine(env, a, ldx, atLevel)) return; // still in play } // All exhausted at this level (atLevel denotes exactly at this level). @@ -1058,45 +1090,43 @@ if (lhs == &t) { // Start or end a scalarized reduction if (atStart) { - Kind kind = merger.exp(last).kind; - Value load = kind == Kind::kReduce - ? getCustomRedId(merger.exp(last).op) - : genTensorLoad(merger, codegen, builder, op, exp); - codegen.redKind = getReduction(kind); - codegen.redExp = exp; - updateReduc(merger, codegen, load); + Kind kind = env.exp(last).kind; + Value load = kind == Kind::kReduce ? getCustomRedId(env.exp(last).op) + : genTensorLoad(env, builder, exp); + env.redKind = getReduction(kind); + env.redExp = exp; + updateReduc(env, load); } else { - Value redVal = codegen.redVal; - updateReduc(merger, codegen, Value()); - codegen.redExp = -1u; - codegen.redKind = kNoReduc; - genTensorStore(merger, codegen, builder, op, exp, redVal); + Value redVal = env.redVal; + updateReduc(env, Value()); + env.redExp = -1u; + env.redKind = kNoReduc; + genTensorStore(env, builder, exp, redVal); } } else { // Start or end loop invariant hoisting of a tensor load. - merger.exp(exp).val = - atStart ? genTensorLoad(merger, codegen, builder, op, exp) : Value(); + env.exp(exp).val = atStart ? 
genTensorLoad(env, builder, exp) : Value(); } - } else if (merger.exp(exp).kind != Kind::kInvariant && - merger.exp(exp).kind != Kind::kIndex) { + } else if (env.exp(exp).kind != Kind::kInvariant && + env.exp(exp).kind != Kind::kIndex) { // Traverse into the binary operations. Note that we only hoist // tensor loads, since subsequent MLIR/LLVM passes know how to // deal with all other kinds of derived loop invariants. - unsigned e0 = merger.exp(exp).children.e0; - unsigned e1 = merger.exp(exp).children.e1; - genInvariants(merger, codegen, builder, op, e0, ldx, atStart, exp); - genInvariants(merger, codegen, builder, op, e1, ldx, atStart, exp); + unsigned e0 = env.exp(exp).children.e0; + unsigned e1 = env.exp(exp).children.e1; + genInvariants(env, builder, e0, ldx, atStart, exp); + genInvariants(env, builder, e1, ldx, atStart, exp); } } /// Generates an expanded access pattern in innermost dimension. -static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, unsigned at, bool atStart) { - OpOperand *lhs = codegen.sparseOut; - if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 || - at != codegen.outerParNest) +static void genExpansion(CodeGenEnv &env, OpBuilder &builder, unsigned at, + bool atStart) { + linalg::GenericOp op = env.linalgOp; + OpOperand *lhs = env.sparseOut; + if (!lhs || env.outerParNest != op.getRank(lhs) - 1 || at != env.outerParNest) return; // not needed at this level - assert(codegen.redVal == nullptr); + assert(env.redVal == nullptr); // Generate start or end of an expanded access pattern. Note that because // an expension does not rely on the ongoing contents of the sparse storage // scheme, we can use the original tensor as incoming SSA value (which @@ -1114,38 +1144,37 @@ auto res = builder.create(loc, TypeRange({t1, t2, t3, t4}), tensor); assert(res.getNumResults() == 4); - assert(!codegen.expValues); - codegen.expValues = res.getResult(0); - codegen.expFilled = res.getResult(1); - codegen.expAdded = res.getResult(2); - codegen.expCount = res.getResult(3); + assert(!env.expValues); + env.expValues = res.getResult(0); + env.expFilled = res.getResult(1); + env.expAdded = res.getResult(2); + env.expCount = res.getResult(3); } else { - assert(codegen.expValues); + assert(env.expValues); SmallVector indices; for (unsigned i = 0; i < at; i++) { - assert(codegen.loopEmitter.getLoopIV(i)); - indices.push_back(codegen.loopEmitter.getLoopIV(i)); + assert(env.getLoopIV(i)); + indices.push_back(env.getLoopIV(i)); } - codegen.insChain = builder.create( - loc, codegen.expValues, codegen.expFilled, codegen.expAdded, - codegen.expCount, codegen.insChain, indices); - codegen.expValues = codegen.expFilled = codegen.expAdded = - codegen.expCount = Value(); + env.insChain = builder.create(loc, env.expValues, env.expFilled, + env.expAdded, env.expCount, + env.insChain, indices); + env.expValues = env.expFilled = env.expAdded = env.expCount = Value(); } } /// Returns parallelization strategy. Any implicit loop in the Linalg /// operation that is marked "parallel" is a candidate. Whether it is actually /// converted to a parallel operation depends on the requested strategy. -static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isSparse) { +static bool isParallelFor(CodeGenEnv &env, bool isOuter, bool isSparse) { // Reject parallelization of sparse output. - if (codegen.sparseOut) + if (env.sparseOut) return false; // Parallel loops on tensor expansion can cause data races. 
- if (codegen.expCount) + if (env.expCount) return false; // Inspect strategy. - switch (codegen.options.parallelizationStrategy) { + switch (env.options.parallelizationStrategy) { case SparseParallelizationStrategy::kNone: return false; case SparseParallelizationStrategy::kDenseOuterLoop: @@ -1161,98 +1190,94 @@ } /// Generates a for-loop on a single index. -static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, bool isOuter, bool isInner, - unsigned idx, size_t tid, size_t dim, +static Operation *genFor(CodeGenEnv &env, OpBuilder &builder, bool isOuter, + bool isInner, unsigned idx, size_t tid, size_t dim, ArrayRef extraTids, ArrayRef extraDims) { + linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); - bool isSparse = isCompressedDLT(merger.getDimLevelType(tid, idx)) || - isSingletonDLT(merger.getDimLevelType(tid, idx)); - bool isParallel = isParallelFor(codegen, isOuter, isSparse); - - Operation *loop = - *genLoopBoundary(codegen, merger, [&](MutableArrayRef reduc) { - if (merger.isFilterLoop(idx)) { - // extraTids/extraDims must be empty because filter loops only - // corresponding to the one and only sparse tensor level. - assert(isSparse && extraTids.empty() && extraDims.empty()); - OpOperand *t = &op->getOpOperand(tid); - auto enc = getSparseTensorEncoding(t->get().getType()); - // Retrieves the affine expression for the filter loop. - AffineExpr a = - op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim)); - return codegen.loopEmitter.enterFilterLoopOverTensorAtDim( - builder, loc, tid, dim, a, reduc); - } - return codegen.loopEmitter.enterLoopOverTensorAtDim( - builder, loc, tid, dim, reduc, isParallel, extraTids, extraDims); - }); + auto iteratorTypes = op.getIteratorTypesArray(); + bool isSparse = isCompressedDLT(env.dimLevelType(tid, idx)) || + isSingletonDLT(env.dimLevelType(tid, idx)); + bool isParallel = isParallelFor(env, isOuter, isSparse); + + Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef reduc) { + if (env.isFilterLoop(idx)) { + // extraTids/extraDims must be empty because filter loops only + // corresponding to the one and only sparse tensor level. + assert(isSparse && extraTids.empty() && extraDims.empty()); + OpOperand *t = &op->getOpOperand(tid); + auto enc = getSparseTensorEncoding(t->get().getType()); + // Retrieves the affine expression for the filter loop. + AffineExpr a = + op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim)); + return env.loopEmitter->enterFilterLoopOverTensorAtDim(builder, loc, tid, + dim, a, reduc); + } + return env.loopEmitter->enterLoopOverTensorAtDim( + builder, loc, tid, dim, reduc, isParallel, extraTids, extraDims); + }); assert(loop); return loop; } /// Emit a while-loop for co-iteration over multiple indices. -static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, unsigned idx, bool needsUniv, - ArrayRef condTids, ArrayRef condDims, +static Operation *genWhile(CodeGenEnv &env, OpBuilder &builder, unsigned idx, + bool needsUniv, ArrayRef condTids, + ArrayRef condDims, ArrayRef extraTids, ArrayRef extraDims) { - - Operation *loop = - *genLoopBoundary(codegen, merger, [&](MutableArrayRef reduc) { - // Construct the while-loop with a parameter for each index. 
- return codegen.loopEmitter.enterCoIterationOverTensorsAtDims( - builder, op.getLoc(), condTids, condDims, needsUniv, reduc, - extraTids, extraDims); - }); + Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef reduc) { + // Construct the while-loop with a parameter for each index. + return env.loopEmitter->enterCoIterationOverTensorsAtDims( + builder, env.linalgOp.getLoc(), condTids, condDims, needsUniv, reduc, + extraTids, extraDims); + }); assert(loop); return loop; } /// Generates a for-loop or a while-loop, depending on whether it implements /// singleton iteration or co-iteration over the given conjunction. -static Operation *genLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, unsigned at, bool needsUniv, - ArrayRef condTids, ArrayRef condDims, - ArrayRef extraTids, +static Operation *genLoop(CodeGenEnv &env, OpBuilder &builder, unsigned at, + bool needsUniv, ArrayRef condTids, + ArrayRef condDims, ArrayRef extraTids, ArrayRef extraDims) { assert(condTids.size() == condDims.size()); assert(extraTids.size() == extraDims.size()); - unsigned idx = codegen.topSort[at]; + unsigned idx = env.topSort[at]; if (condTids.size() == 1) { bool isOuter = at == 0; - bool isInner = at == codegen.topSort.size() - 1; - return genFor(merger, codegen, builder, op, isOuter, isInner, idx, - condTids.front(), condDims.front(), extraTids, extraDims); + bool isInner = at == env.topSort.size() - 1; + return genFor(env, builder, isOuter, isInner, idx, condTids.front(), + condDims.front(), extraTids, extraDims); } - return genWhile(merger, codegen, builder, op, idx, needsUniv, condTids, - condDims, extraTids, extraDims); + return genWhile(env, builder, idx, needsUniv, condTids, condDims, extraTids, + extraDims); } /// Generates the induction structure for a while-loop. -static void finalizeWhileOp(Merger &merger, CodeGen &codegen, - OpBuilder &builder, linalg::GenericOp op, - unsigned idx, bool needsUniv, BitVector &induction, +static void finalizeWhileOp(CodeGenEnv &env, OpBuilder &builder, unsigned idx, + bool needsUniv, BitVector &induction, scf::WhileOp whileOp) { - Location loc = op.getLoc(); + Location loc = env.linalgOp.getLoc(); // Finalize each else branch of all if statements. - if (codegen.redVal || codegen.expValues || codegen.insChain) { + if (env.redVal || env.expValues || env.insChain) { while (auto ifOp = dyn_cast_or_null( builder.getInsertionBlock()->getParentOp())) { unsigned y = 0; SmallVector yields; - if (codegen.redVal) { - yields.push_back(codegen.redVal); - updateReduc(merger, codegen, ifOp.getResult(y++)); + if (env.redVal) { + yields.push_back(env.redVal); + updateReduc(env, ifOp.getResult(y++)); } - if (codegen.expValues) { - yields.push_back(codegen.expCount); - codegen.expCount = ifOp->getResult(y++); + if (env.expValues) { + yields.push_back(env.expCount); + env.expCount = ifOp->getResult(y++); } - if (codegen.insChain) { - yields.push_back(codegen.insChain); - codegen.insChain = ifOp->getResult(y++); + if (env.insChain) { + yields.push_back(env.insChain); + env.insChain = ifOp->getResult(y++); } assert(y == yields.size()); builder.create(loc, yields); @@ -1263,62 +1288,61 @@ } /// Generates a single if-statement within a while-loop. 
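// For illustration: a scalar analogue (plain C++, hypothetical data layout) of
// the co-iteration that genWhile emits together with the if-statements from
// genIf below: walk two sorted coordinate lists at once, and guard each body
// with "is this tensor defined at the current index?" tests, here for the
// element-wise sum x = a + b of two sparse vectors.
#include <algorithm>
#include <cstddef>
#include <vector>

struct SparseVecSketch {
  std::vector<size_t> coords; // sorted, unique coordinates
  std::vector<double> values; // value stored at each coordinate
};

static SparseVecSketch addSketch(const SparseVecSketch &a,
                                 const SparseVecSketch &b) {
  SparseVecSketch x;
  size_t ia = 0, ib = 0;
  while (ia < a.coords.size() && ib < b.coords.size()) {
    size_t i = std::min(a.coords[ia], b.coords[ib]); // current loop index
    bool inA = a.coords[ia] == i;                    // genIf-style clauses
    bool inB = b.coords[ib] == i;
    double v;
    if (inA && inB)
      v = a.values[ia] + b.values[ib]; // intersection lattice point
    else if (inA)
      v = a.values[ia]; // a-only lattice point
    else
      v = b.values[ib]; // b-only lattice point
    x.coords.push_back(i);
    x.values.push_back(v);
    ia += inA; // advance only the operands present at index i
    ib += inB;
  }
  // The remaining singleton iterations correspond to the simpler for-loops
  // generated after the while-loop for the lower lattice points.
  for (; ia < a.coords.size(); ia++) {
    x.coords.push_back(a.coords[ia]);
    x.values.push_back(a.values[ia]);
  }
  for (; ib < b.coords.size(); ib++) {
    x.coords.push_back(b.coords[ib]);
    x.values.push_back(b.values[ib]);
  }
  return x;
}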
-static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, unsigned idx, +static scf::IfOp genIf(CodeGenEnv &env, OpBuilder &builder, unsigned idx, BitVector &conditions) { - Location loc = op.getLoc(); + Location loc = env.linalgOp.getLoc(); SmallVector types; Value cond; for (unsigned b = 0, be = conditions.size(); b < be; b++) { if (!conditions[b]) continue; - unsigned tensor = merger.tensor(b); - assert(idx == merger.index(b)); + unsigned tensor = env.merger.tensor(b); + assert(idx == env.merger.index(b)); Value clause; - if (isCompressedDLT(merger.getDimLevelType(b)) || - isSingletonDLT(merger.getDimLevelType(b))) { - auto dim = *merger.getDimNum(tensor, idx); - Value op1 = codegen.loopEmitter.getCoord()[tensor][dim]; - Value op2 = codegen.getLoopIdxValue(idx); + if (isCompressedDLT(env.dimLevelType(b)) || + isSingletonDLT(env.dimLevelType(b))) { + auto dim = *env.merger.getDimNum(tensor, idx); + Value op1 = env.loopEmitter->getCoord()[tensor][dim]; + Value op2 = env.getLoopIdxValue(idx); clause = builder.create(loc, arith::CmpIPredicate::eq, op1, op2); } else { - assert(isDenseDLT(merger.getDimLevelType(b)) || - isUndefDLT(merger.getDimLevelType(b))); + assert(isDenseDLT(env.dimLevelType(b)) || + isUndefDLT(env.dimLevelType(b))); clause = constantI1(builder, loc, true); } cond = cond ? builder.create(loc, cond, clause) : clause; } - if (codegen.redVal) - types.push_back(codegen.redVal.getType()); - if (codegen.expValues) + if (env.redVal) + types.push_back(env.redVal.getType()); + if (env.expValues) types.push_back(builder.getIndexType()); - if (codegen.insChain) - types.push_back(codegen.insChain.getType()); + if (env.insChain) + types.push_back(env.insChain.getType()); scf::IfOp ifOp = builder.create(loc, types, cond, /*else=*/true); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); return ifOp; } /// Generates end of true branch of if-statement within a while-loop. -static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, scf::IfOp ifOp, Operation *loop, - Value redInput, Value cntInput, Value insInput) { +static void endIf(CodeGenEnv &env, OpBuilder &builder, scf::IfOp ifOp, + Operation *loop, Value redInput, Value cntInput, + Value insInput) { SmallVector operands; - if (codegen.redVal) { - operands.push_back(codegen.redVal); - updateReduc(merger, codegen, redInput); + if (env.redVal) { + operands.push_back(env.redVal); + updateReduc(env, redInput); } - if (codegen.expValues) { - operands.push_back(codegen.expCount); - codegen.expCount = cntInput; + if (env.expValues) { + operands.push_back(env.expCount); + env.expCount = cntInput; } - if (codegen.insChain) { - operands.push_back(codegen.insChain); - codegen.insChain = insInput; + if (env.insChain) { + operands.push_back(env.insChain); + env.insChain = insInput; } if (!operands.empty()) - builder.create(op.getLoc(), operands); + builder.create(env.linalgOp.getLoc(), operands); builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); } @@ -1328,24 +1352,24 @@ /// Starts a loop sequence at given level. Returns true if /// the universal loop index must be maintained at this level. 
-static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, unsigned exp, unsigned at, - unsigned idx, unsigned ldx, unsigned lts) { - assert(!codegen.getLoopIdxValue(idx)); +static bool startLoopSeq(CodeGenEnv &env, OpBuilder &builder, unsigned exp, + unsigned at, unsigned idx, unsigned ldx, + unsigned lts) { + assert(!env.getLoopIdxValue(idx)); // Emit invariants at this loop sequence level. - genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/true); + genInvariants(env, builder, exp, ldx, /*atStart=*/true); // Emit access pattern expansion for sparse tensor output. - genExpansion(merger, codegen, builder, op, at, /*atStart=*/true); + genExpansion(env, builder, at, /*atStart=*/true); // Emit further intitialization at this loop sequence level. - unsigned l0 = merger.set(lts)[0]; + unsigned l0 = env.set(lts)[0]; bool needsUniv = false; SmallVector tids; SmallVector dims; - merger.foreachTidDimPairInBits( - merger.lat(l0).bits, + env.merger.foreachTidDimPairInBits( + env.lat(l0).bits, [&](unsigned b, unsigned tid, Optional dim, DimLevelType dlt) { - assert(merger.index(b) == idx); + assert(env.merger.index(b) == idx); if (isDenseDLT(dlt) || isUndefDLT(dlt)) { needsUniv = true; } else { @@ -1355,28 +1379,27 @@ } }); - codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), tids, dims); + env.loopEmitter->enterNewLoopSeq(builder, env.linalgOp.getLoc(), tids, dims); // Maintain the universal index only if it is actually // consumed by a subsequent lattice point. if (needsUniv) { - unsigned lsize = merger.set(lts).size(); + unsigned lsize = env.set(lts).size(); for (unsigned i = 1; i < lsize; i++) { - unsigned li = merger.set(lts)[i]; - if (!merger.hasAnySparse(merger.lat(li).simple)) + unsigned li = env.set(lts)[i]; + if (!env.merger.hasAnySparse(env.lat(li).simple)) return true; } } return false; } -static void genConstantDenseAddressFromLevel(CodeGen &codegen, - OpBuilder &builder, - linalg::GenericOp op, unsigned tid, +static void genConstantDenseAddressFromLevel(CodeGenEnv &env, + OpBuilder &builder, unsigned tid, unsigned lvl) { // TODO: Handle affine expression on output tensor. + linalg::GenericOp op = env.linalgOp; assert(tid < op.getNumDpsInputs()); - OpOperand *input = op.getDpsInputOperands()[tid]; ArrayRef affines = op.getMatchingIndexingMap(input).getResults(); auto enc = getSparseTensorEncoding(input->get().getType()); @@ -1384,42 +1407,38 @@ for (unsigned i = lvl, e = affines.size(); i < e; i++) { AffineExpr affine = affines[toOrigDim(enc, i)]; if (isDenseDLT(getDimLevelType(enc, i)) && - affine.isa()) { - codegen.loopEmitter.genDenseAffineAddressAtCurLevel( + affine.isa()) + env.loopEmitter->genDenseAffineAddressAtCurLevel( builder, op.getLoc(), input->getOperandNumber(), i, affine); - } else { - // Breaks on first non-dense non-constant level. - return; - } + else + return; // break on first non-dense non-constant level } } } -static void genInitConstantDenseAddress(CodeGen &codegen, - RewriterBase &rewriter, - linalg::GenericOp op) { - // We can generates address for constant affine expression before any loops +static void genInitConstantDenseAddress(CodeGenEnv &env, + RewriterBase &rewriter) { + // We can generate address for constant affine expression before any loops // starting from the first level as they do not depend on any thing. // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two // levels can be determined before loops. 
- for (unsigned tid = 0, e = op.getNumDpsInputs(); tid < e; tid++) - genConstantDenseAddressFromLevel(codegen, rewriter, op, tid, 0); + for (unsigned tid = 0, e = env.linalgOp.getNumDpsInputs(); tid < e; tid++) + genConstantDenseAddressFromLevel(env, rewriter, tid, 0); } static void translateBitsToTidDimPairs( - Merger &merger, CodeGen &codegen, linalg::GenericOp op, unsigned li, - unsigned idx, SmallVectorImpl &condTids, - SmallVectorImpl &condDims, SmallVectorImpl &extraTids, - SmallVectorImpl &extraDims, SmallVectorImpl &affineTids, - SmallVectorImpl &affineDims, SmallVectorImpl &exps) { - - const BitVector &all = merger.lat(li).bits; - const BitVector &simple = merger.lat(li).simple; + CodeGenEnv &env, unsigned li, unsigned idx, + SmallVectorImpl &condTids, SmallVectorImpl &condDims, + SmallVectorImpl &extraTids, SmallVectorImpl &extraDims, + SmallVectorImpl &affineTids, SmallVectorImpl &affineDims, + SmallVectorImpl &exps) { + const BitVector &all = env.lat(li).bits; + const BitVector &simple = env.lat(li).simple; // Converts bits to array + dim pair - merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid, - Optional dim, - DimLevelType dlt) { + env.merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid, + Optional dim, + DimLevelType dlt) { if (simple.test(b)) { if (isUndefDLT(dlt)) { // An undefined dlt in the lattices, we probably mean to iterate based @@ -1428,8 +1447,8 @@ // output tensor). // out[i][j] = invariant; or a broadcast // out[i][j] = in[i] (j is undef for input) - tid = merger.getOutTensorID(); - dim = merger.getDimNum(tid, idx); + tid = env.merger.getOutTensorID(); + dim = env.merger.getDimNum(tid, idx); // Skips invalid dim (e.g., when this is a zero ranked tensor). if (!dim) return; @@ -1442,6 +1461,7 @@ extraDims.push_back(*dim); } else { assert(isUndefDLT(dlt)); + linalg::GenericOp op = env.linalgOp; if (tid >= op.getNumDpsInputs()) // We only handle affine expression on input tensors (for now). return; @@ -1464,7 +1484,7 @@ // Constant affine expression are handled in genLoop if (!exp.isa()) { bool atLevel = false; - if (isInvariantAffine(codegen, exp, idx, atLevel) && atLevel) { + if (isInvariantAffine(env, exp, idx, atLevel) && atLevel) { // If the compound affine is invariant and we are right at the // level. We need to generate the address according to the affine // expression. This is also the best place we can do it to avoid @@ -1482,20 +1502,19 @@ } }); - if (isDenseDLT(merger.getDimLevelType(merger.getOutTensorID(), idx))) { + if (isDenseDLT(env.dimLevelType(env.merger.getOutTensorID(), idx))) { // Note that we generate dense indices of the output tensor // unconditionally, since they may not appear in the lattice, but may be // needed for linearized codegen. - auto dim = *merger.getDimNum(merger.getOutTensorID(), idx); - extraTids.push_back(merger.getOutTensorID()); + auto dim = *env.merger.getDimNum(env.merger.getOutTensorID(), idx); + extraTids.push_back(env.merger.getOutTensorID()); extraDims.push_back(dim); } } /// Starts a single loop in current sequence. -static Operation *startLoop(Merger &merger, CodeGen &codegen, - OpBuilder &builder, linalg::GenericOp op, - unsigned at, unsigned li, bool needsUniv) { +static Operation *startLoop(CodeGenEnv &env, OpBuilder &builder, unsigned at, + unsigned li, bool needsUniv) { // The set of tensors + dims to generate loops on SmallVector condTids, condDims; // The set of (dense) tensors that is optimized from condition, yet still @@ -1506,17 +1525,16 @@ // level. 
   SmallVector<size_t> affineTids, affineDims;
   SmallVector<AffineExpr> affines;
+  translateBitsToTidDimPairs(env, li, env.topSort[at], condTids, condDims,
+                             extraTids, extraDims, affineTids, affineDims,
+                             affines);

-  translateBitsToTidDimPairs(merger, codegen, op, li, codegen.topSort[at],
-                             condTids, condDims, extraTids, extraDims,
-                             affineTids, affineDims, affines);
   // Emit the for/while-loop control.
-  Operation *loop = genLoop(merger, codegen, builder, op, at, needsUniv,
-                            condTids, condDims, extraTids, extraDims);
-
+  Operation *loop = genLoop(env, builder, at, needsUniv, condTids, condDims,
+                            extraTids, extraDims);
   for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) {
-    codegen.loopEmitter.genDenseAffineAddressAtCurLevel(builder, op.getLoc(),
-                                                        tid, dim, exp);
+    env.loopEmitter->genDenseAffineAddressAtCurLevel(
+        builder, env.linalgOp.getLoc(), tid, dim, exp);
   }

   // Until now, we have entered every <tid, dim> pair in {cond, extra,
@@ -1525,27 +1543,25 @@
   auto allTids = llvm::concat<size_t>(condTids, extraTids, affineTids);
   auto allDims = llvm::concat<size_t>(condDims, extraDims, affineDims);
   for (auto [tid, dim] : llvm::zip(allTids, allDims)) {
-    if (tid != merger.getOutTensorID())
-      genConstantDenseAddressFromLevel(codegen, builder, op, tid, dim + 1);
+    if (tid != env.merger.getOutTensorID())
+      genConstantDenseAddressFromLevel(env, builder, tid, dim + 1);
   }

   return loop;
 }

 /// Ends a single loop in current sequence. Returns new values for needsUniv.
-static bool endLoop(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
-                    linalg::GenericOp op, Operation *loop, unsigned idx,
-                    unsigned li, bool needsUniv) {
+static bool endLoop(CodeGenEnv &env, RewriterBase &rewriter, Operation *loop,
+                    unsigned idx, unsigned li, bool needsUniv) {
   // End a while-loop.
   if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
-    finalizeWhileOp(merger, codegen, rewriter, op, idx, needsUniv,
-                    merger.lat(li).bits, whileOp);
+    finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp);
   } else {
     needsUniv = false;
   }

-  genLoopBoundary(codegen, merger, [&](MutableArrayRef<Value> reduc) {
-    codegen.loopEmitter.exitCurrentLoop(rewriter, op.getLoc(), reduc);
+  genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
+    env.loopEmitter->exitCurrentLoop(rewriter, env.linalgOp.getLoc(), reduc);
     return std::nullopt;
   });

@@ -1553,85 +1569,79 @@
 }

 /// Ends a loop sequence at given level.
-static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                       linalg::GenericOp op, unsigned exp, unsigned at,
-                       unsigned idx, unsigned ldx) {
-  assert(codegen.getLoopIdxValue(idx) == nullptr);
-  codegen.loopEmitter.exitCurrentLoopSeq();
+static void endLoopSeq(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
+                       unsigned at, unsigned idx, unsigned ldx) {
+  assert(env.getLoopIdxValue(idx) == nullptr);
+  env.loopEmitter->exitCurrentLoopSeq();
   // Unmark bookkeeping of invariants and loop index.
-  genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/false);
+  genInvariants(env, builder, exp, ldx, /*atStart=*/false);
   // Finalize access pattern expansion for sparse tensor output.
-  genExpansion(merger, codegen, builder, op, at, /*atStart=*/false);
+  genExpansion(env, builder, at, /*atStart=*/false);
 }

 /// Recursively generates code while computing iteration lattices in order
 /// to manage the complexity of implementing co-iteration over unions
 /// and intersections of sparse iterations spaces.
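+/// For example, for a kernel x(i) = a(i) + b(i) with both a and b sparse,
+/// the lattice for index i holds a point for the conjunction (both operands
+/// present) and one for each single operand; the emitted while-loop then
+/// co-iterates over a and b with if-statements selecting the lattice points
+/// that apply at each step.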
-static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
-                    linalg::GenericOp op, unsigned exp, unsigned at) {
+static void genStmt(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp,
+                    unsigned at) {
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
-  if (at == codegen.topSort.size()) {
-    unsigned ldx = codegen.topSort[at - 1];
-    Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx);
-    genTensorStore(merger, codegen, rewriter, op, exp, rhs);
+  if (at == env.topSort.size()) {
+    unsigned ldx = env.topSort[at - 1];
+    Value rhs = genExp(env, rewriter, exp, ldx);
+    genTensorStore(env, rewriter, exp, rhs);
     return;
   }

   // Construct iteration lattices for current loop index, with L0 at top.
-  unsigned idx = codegen.topSort[at];
-  unsigned ldx = at == 0 ? -1u : codegen.topSort[at - 1];
-  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
+  unsigned idx = env.topSort[at];
+  unsigned ldx = at == 0 ? -1u : env.topSort[at - 1];
+  unsigned lts = env.merger.optimizeSet(env.merger.buildLattices(exp, idx));

   // TODO: sort
   // TODO: dedup

   // Start a loop sequence.
-  bool needsUniv =
-      startLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx, lts);
+  bool needsUniv = startLoopSeq(env, rewriter, exp, at, idx, ldx, lts);

   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
-  unsigned lsize = merger.set(lts).size();
+  unsigned lsize = env.set(lts).size();
   for (unsigned i = 0; i < lsize; i++) {
     // Start a loop.
-    unsigned li = merger.set(lts)[i];
-    Operation *loop =
-        startLoop(merger, codegen, rewriter, op, at, li, needsUniv);
+    unsigned li = env.set(lts)[i];
+    Operation *loop = startLoop(env, rewriter, at, li, needsUniv);

     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
-    Value redInput = codegen.redVal;
-    Value cntInput = codegen.expCount;
-    Value insInput = codegen.insChain;
+    Value redInput = env.redVal;
+    Value cntInput = env.expCount;
+    Value insInput = env.insChain;
     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
     for (unsigned j = 0; j < lsize; j++) {
-      unsigned lj = merger.set(lts)[j];
-      unsigned ej = merger.lat(lj).exp;
-      if (li == lj || merger.latGT(li, lj)) {
+      unsigned lj = env.set(lts)[j];
+      unsigned ej = env.lat(lj).exp;
+      if (li == lj || env.merger.latGT(li, lj)) {
         // Recurse into body of each branch.
         if (isWhile) {
-          scf::IfOp ifOp =
-              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
-          genStmt(merger, codegen, rewriter, op, ej, at + 1);
-          endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput,
-                insInput);
+          scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple);
+          genStmt(env, rewriter, ej, at + 1);
+          endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput);
         } else {
-          genStmt(merger, codegen, rewriter, op, ej, at + 1);
+          genStmt(env, rewriter, ej, at + 1);
         }
       }
     }

     // End a loop.
-    needsUniv =
-        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
+    needsUniv = endLoop(env, rewriter, loop, idx, li, needsUniv);
   }

   // End a loop sequence.
-  endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
+  endLoopSeq(env, rewriter, exp, at, idx, ldx);
 }

 /// Converts the result computed by the sparse kernel into the required form.
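+/// An annotated (sparse) output is rematerialized from its underlying
+/// storage, or from the insertion chain when insertions occurred, while a
+/// non-annotated output is simply loaded back from its bufferized values.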
-static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
-                      linalg::GenericOp op) {
+static void genResult(CodeGenEnv &env, RewriterBase &rewriter) {
+  linalg::GenericOp op = env.linalgOp;
   OpOperand *lhs = op.getDpsInitOperand(0);
   Value tensor = lhs->get();
   Type resType = tensor.getType();
@@ -1639,14 +1649,14 @@
     // The sparse tensor rematerializes from the original sparse tensor's
     // underlying sparse storage format. For an insertion chain, the
     // tensor materializes from the chain with 'hasInserts' enabled.
-    bool hasInserts = codegen.sparseOut == lhs;
+    bool hasInserts = env.sparseOut == lhs;
     if (hasInserts)
-      tensor = codegen.insChain;
+      tensor = env.insChain;
     rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
   } else {
     // To rematerialize an non-annotated tensor, simply load it
     // from the bufferized value.
-    Value val = codegen.loopEmitter.getValBuffer().back(); // value array
+    Value val = env.getValBuffer().back(); // value array
     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
   }
 }
@@ -1656,6 +1666,7 @@
 //===----------------------------------------------------------------------===//

 namespace {
+
 /// Sparse rewriting rule for generic Lingalg operation.
 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
 public:
@@ -1664,86 +1675,84 @@
   LogicalResult matchAndRewrite(linalg::GenericOp op,
                                 PatternRewriter &rewriter) const override {
-    // Detects sparse annotations and translate the per-dimension sparsity
-    // information for all tensors to loop indices in the kernel.
+    // Only accept single output operations.
     if (op.getNumDpsInits() != 1)
       return failure();
+
+    // Sets up a code generation environment.
     unsigned numTensors = op->getNumOperands();
     unsigned numLoops = op.getNumLoops();
     unsigned numFilterLoops = getNumCompoundAffineOnSparseDims(op);
-    Merger merger(numTensors, numLoops, numFilterLoops);
-    if (!findSparseAnnotations(merger, op))
+    CodeGenEnv env(op, options, numTensors, numLoops, numFilterLoops);
+
+    // Detects sparse annotations and translates the per-dimension sparsity
+    // information for all tensors to loop indices in the kernel.
+    if (!findSparseAnnotations(env))
       return failure();

     // Builds the tensor expression for the Linalg operation in SSA form.
-    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
+    Optional<unsigned> optExp = env.merger.buildTensorExpFromLinalg(op);
     if (!optExp.has_value())
       return failure();
-
     unsigned exp = *optExp;
-    OpOperand *sparseOut = nullptr;
-    unsigned outerParNest = 0;
+
     // Computes a topologically sorted iteration graph to ensure tensors
     // are visited in natural index order. Gradually relaxes the considered
     // constraints until an acyclic iteration graph results, such that sparse
     // code generation can proceed. As a last resort, an attempt is made
     // to resolve cycles by inserting a conversion.
-    std::vector<unsigned> topSort;
-    // Whether the current GenericOp is admissible.
     bool isAdmissible = false;
     bool hasCycle = true;
     // An const list of all masks that we used for interation graph
-    // computation. Must be ordered from strict -> loose.
+    // computation. Must be ordered from more strict to less strict.
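+    // That is, the iteration graph is first constrained by all dimensions
+    // (kIncludeAll) and, if that yields a cycle, progressively by fewer
+    // constraints, down to the sparse dimensions only (kSparseOnly).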
     const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
                           SortMask::kIncludeDense, SortMask::kSparseOnly};
     for (auto mask : allMask)
-      if (computeIterationGraph(merger, op, topSort, mask)) {
+      if (computeIterationGraph(env, mask)) {
         hasCycle = false;
-        if (isAdmissibleTensorExp(merger, op, topSort, exp, &sparseOut,
-                                  outerParNest)) {
+        if (isAdmissibleTensorExp(env, exp)) {
           isAdmissible = true;
           break;
         }
         // else try a set of less strict constraints.
       }
-
     if (hasCycle)
-      // Give it one last shot to resolve the cycle.
-      return resolveCycle(merger, rewriter, op);
+      return resolveCycle(env, rewriter); // one last shot
     if (!isAdmissible)
-      // Inadmissible expression, reject.
-      return failure();
-
-    merger.setHasSparseOut(sparseOut != nullptr);
+      return failure(); // inadmissible expression, reject

+    // Updates environment with a loop emitter.
+    // TODO: refactor so that emitter can be constructed earlier
+    // and updating is made easy, i.e. remove this whole block?
     SmallVector<Value> tensors;
     for (OpOperand &t : op->getOpOperands())
       tensors.push_back(t.get());
+    SparseTensorLoopEmitter lpe(
+        tensors,
+        StringAttr::get(op.getContext(), linalg::GenericOp::getOperationName()),
+        /*hasOutput=*/true, /*isSparseOut=*/env.sparseOut != nullptr,
+        env.topSort);
+    env.startEmit(&lpe);

     // Recursively generates code if admissible.
-    CodeGen codegen(options, op.getContext(), tensors, numTensors, numLoops,
-                    sparseOut, outerParNest, topSort);
-    genBuffers(merger, codegen, rewriter, op);
-    genInitConstantDenseAddress(codegen, rewriter, op);
-    genStmt(merger, codegen, rewriter, op, exp, 0);
-    genResult(merger, codegen, rewriter, op);
+    genBuffers(env, rewriter);
+    genInitConstantDenseAddress(env, rewriter);
+    genStmt(env, rewriter, exp, 0);
+    genResult(env, rewriter);
     return success();
   }

 private:
   // Last resort cycle resolution.
-  LogicalResult resolveCycle(Merger &merger, PatternRewriter &rewriter,
-                             linalg::GenericOp op) const {
+  LogicalResult resolveCycle(CodeGenEnv &env, PatternRewriter &rewriter) const {
     // Compute topological sort while leaving out every
     // sparse input tensor in succession until an acylic
     // iteration graph results.
-    std::vector<unsigned> topSort;
-    for (OpOperand *t : op.getDpsInputOperands()) {
+    for (OpOperand *t : env.linalgOp.getDpsInputOperands()) {
       unsigned tensor = t->getOperandNumber();
       Value tval = t->get();
       auto srcEnc = getSparseTensorEncoding(tval.getType());
-      if (!srcEnc ||
-          !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly, t))
+      if (!srcEnc || !computeIterationGraph(env, SortMask::kSparseOnly, t))
         continue;
       // Found an input tensor that resolves the cycle by inserting a
       // conversion into a sparse tensor that adheres to the iteration
@@ -1754,16 +1763,15 @@
       //
       auto srcTp = tval.getType().cast<RankedTensorType>();
       auto dstEnc = SparseTensorEncodingAttr::get(
-          op->getContext(), srcEnc.getDimLevelType(),
-          permute(merger, getContext(), op.getMatchingIndexingMap(t),
-                  topSort), // new order
+          getContext(), srcEnc.getDimLevelType(),
+          permute(env, env.linalgOp.getMatchingIndexingMap(t)), // new order
           srcEnc.getHigherOrdering(), srcEnc.getPointerBitWidth(),
           srcEnc.getIndexBitWidth());
       auto dstTp = RankedTensorType::get(srcTp.getShape(),
                                          srcTp.getElementType(), dstEnc);
       auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
-      op->setOperand(tensor, convert);
-      rewriter.setInsertionPointAfter(op);
+      env.linalgOp->setOperand(tensor, convert);
+      rewriter.setInsertionPointAfter(env.linalgOp);
       rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
       return success();
     }