diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -113,9 +113,9 @@ } /// Helper method to inspect affine expressions. Rejects cases where the -/// same index is used in more than one dimension of a tensor. Also rejects -/// affine expressions that are not a direct index for annotated tensors. -/// TODO: accept more affine cases for sparse tensors +/// same index is used more than once. Also rejects affine expressions +/// that are not a direct index for annotated tensors. +// TODO: accept more affine cases for sparse tensors static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim, bool isDense) { switch (a.getKind()) { @@ -263,6 +263,22 @@ return true; } +/// Returns true if tensor is a function argument annotated as in-place. +static bool isInPlace(Value val) { + if (auto arg = val.dyn_cast()) + if (auto funcOp = dyn_cast(arg.getOwner()->getParentOp())) + if (auto attr = funcOp.getArgAttrOfType( + arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName)) + return attr.getValue(); + return false; +} + +/// Returns true if tensor materializes into the computation. +static bool isMaterializing(Value val) { + return val.getDefiningOp() || + val.getDefiningOp(); +} + /// Returns true when the tensor expression is admissable for codegen. /// Since all sparse input tensors are admissable, we just need to check /// whether the output tensor in the tensor expression codegen is admissable. @@ -288,16 +304,17 @@ return true; // A tensor expression with a sparse output tensor that changes its values // but not its nonzero structure, an operation called "simply dynamic" in - // [Bik96,Ch9], is also admissable without special codegen.
+ // [Bik96,Ch9], is also admissable without special codegen, provided + // the tensor's underlying sparse storage scheme can be modified in place. if (merger.isConjunction(tensor, exp)) - return true; + return isInPlace(lhs->get()); // Reject for now since this requires changes to the nonzero structure. // TODO: implement "workspaces" [Kjolstad2019] return false; } //===----------------------------------------------------------------------===// -// Sparse compiler synthesis methods. +// Sparse compiler synthesis methods (statements and expressions). //===----------------------------------------------------------------------===// /// Maps reduction kind to name encoding. @@ -350,7 +367,7 @@ case kXor: { // Initialize reduction vector to: | 0 | .. | 0 | r | Attribute zero = rewriter.getZeroAttr(vtp); - Value vec = rewriter.create(loc, vtp, zero); + Value vec = rewriter.create(loc, vtp, zero); return rewriter.create(loc, r, vec, 0); } case kProduct: { @@ -361,8 +378,8 @@ one = rewriter.getFloatAttr(etp, 1.0); else one = rewriter.getIntegerAttr(etp, 1); - Value vec = - rewriter.create(loc, vtp, DenseElementsAttr::get(vtp, one)); + Value vec = rewriter.create( + loc, vtp, DenseElementsAttr::get(vtp, one)); return rewriter.create(loc, r, vec, 0); } case kAnd: @@ -380,16 +397,6 @@ return rewriter.getIntegerType(width); } -/// Detects in-place annotation on tensor argument. -static bool getInPlace(Value val) { - if (auto arg = val.dyn_cast()) - if (auto funcOp = dyn_cast(arg.getOwner()->getParentOp())) - if (auto attr = funcOp.getArgAttrOfType( - arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName)) - return attr.getValue(); - return false; -} - /// Generates buffer for the output tensor. Note that all sparse kernels /// assume that when all elements are written to (viz. x(i) = y(i) * z(i)), /// the output buffer is already initialized to all zeroes and only nonzeroes @@ -405,18 +412,19 @@ // be generated for the tensor present in the outs() clause.
This has // the major advantage that the sparse kernel only updates the nonzero // positions for the output tensor. - if (getInPlace(tensor)) + if (isInPlace(tensor)) return rewriter.create(loc, denseTp, tensor); // By default, a new buffer is allocated which is initialized to the // tensor defined in the outs() clause. This is always correct but // introduces a dense initialization component that may negatively // impact the running complexity of the sparse kernel. If the tensor - // materializes within this method, we need to preserve the zero + // materializes into the computation, we need to preserve the zero // initialization assumption of all sparse output buffers. - if (auto init = tensor.getDefiningOp()) { + if (isMaterializing(tensor)) { Type tp = denseTp.getElementType(); Value alloc = rewriter.create(loc, denseTp, args); - Value zero = rewriter.create(loc, tp, rewriter.getZeroAttr(tp)); + Value zero = + rewriter.create(loc, tp, rewriter.getZeroAttr(tp)); rewriter.create(loc, zero, alloc); return alloc; } @@ -429,7 +437,7 @@ /// Local bufferization of all dense and sparse data structures. /// This code enables testing the first prototype sparse compiler. // TODO: replace this with a proliferated bufferization strategy -static bool genBuffers(Merger &merger, CodeGen &codegen, +static void genBuffers(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op) { Location loc = op.getLoc(); assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); @@ -486,15 +494,12 @@ genOutputBuffer(codegen, rewriter, op, denseTp, args); } else { // Annotated sparse tensors. - if (tensor == op.getNumInputs() && !getInPlace(t->get())) - return false; // reject output if not in-place auto dynShape = {ShapedType::kDynamicSize}; auto sparseTp = MemRefType::get(dynShape, elementType); codegen.buffers[tensor] = rewriter.create(loc, sparseTp, t->get()); } } - return true; } /// Constructs vector type.
@@ -623,7 +628,9 @@ if (enc) { // Note that currently, all sparse subscripts are simple. // TODO: accept affine too? - unsigned idx = map.getDimPosition(perm(enc, rank - 1)); + AffineExpr a = map.getResult(perm(enc, rank - 1)); + assert(a.getKind() == AffineExprKind::DimId); + unsigned idx = a.cast().getPosition(); assert(codegen.pidxs[tensor][idx] != nullptr); args.push_back(codegen.pidxs[tensor][idx]); // position index } else { @@ -841,6 +848,7 @@ if (lhs == t) { codegen.redExp = hoist ? exp : -1u; codegen.redKind = getReduction(last); + assert(!codegen.redVal); } else if (atLevel) { merger.exp(exp).val = hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); @@ -948,6 +956,7 @@ // dimension (true non-unit stride) or if the innermost index appears // in a compound subscript in the innermost dimension. Even if the // latter is unit stride, it does not play well with scatter/gather. + // TODO: accept unit stride affine innermost like a[i,j+k+1]? if (a.isFunctionOfDim(idx) && ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) return false; @@ -1209,6 +1218,83 @@ return ifOp; } +//===----------------------------------------------------------------------===// +// Sparse compiler synthesis methods (loop sequence). +//===----------------------------------------------------------------------===// + +/// Starts a loop sequence at given level. Returns true if +/// the universal loop index must be maintained at this level. +static bool startLoopSeq(Merger &merger, CodeGen &codegen, + PatternRewriter &rewriter, linalg::GenericOp op, + std::vector &topSort, unsigned exp, + unsigned at, unsigned ldx, unsigned lts) { + assert(codegen.curVecLength == 1); + assert(!merger.set(lts).empty()); + // Emit invariants at this loop sequence level. + genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true); + // Emit further initialization at this loop sequence level.
+ unsigned l0 = merger.set(lts)[0]; + if (genInit(merger, codegen, rewriter, op, topSort, at, + merger.lat(l0).bits)) { + // Maintain the universal index only if it is actually + // consumed by a subsequent lattice point. + unsigned lsize = merger.set(lts).size(); + for (unsigned i = 1; i < lsize; i++) { + unsigned li = merger.set(lts)[i]; + if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) + return true; + } + } + return false; +} + +/// Starts a single loop in current sequence. +static Operation *startLoop(Merger &merger, CodeGen &codegen, + PatternRewriter &rewriter, linalg::GenericOp op, + std::vector &topSort, unsigned at, + unsigned li, bool needsUniv) { + assert(codegen.curVecLength == 1); + // Emit the for/while-loop control. + Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at, + needsUniv, merger.lat(li).simple); + // Emit the locals for this loop. + genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, + merger.lat(li).bits); + return loop; +} + +/// Ends a single loop in current sequence. Returns the new value of needsUniv. +static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, + linalg::GenericOp op, Operation *loop, unsigned idx, + unsigned li, bool needsUniv) { + codegen.curVecLength = 1; + // End a while-loop. + if (auto whileOp = dyn_cast(loop)) { + rewriter.setInsertionPointToEnd(&whileOp.after().front()); + genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv, + merger.lat(li).bits, whileOp.results()); + return needsUniv; + } + // End a for-loop. + if (codegen.redVal) { + rewriter.create(op.getLoc(), codegen.redVal); + codegen.redVal = loop->getResult(0); + } + return false; +} + +/// Ends a loop sequence at given level. +static void endLoopSeq(Merger &merger, CodeGen &codegen, + PatternRewriter &rewriter, linalg::GenericOp op, + unsigned exp, unsigned idx, unsigned ldx) { + assert(codegen.curVecLength == 1); + // Finalize any pending reduction.
+ genReductionEnd(merger, codegen, rewriter, op); + // Unmark bookkeeping of invariants and loop index. + genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false); + codegen.loops[idx] = Value(); +} + /// Recursively generates code while computing iteration lattices in order /// to manage the complexity of implementing co-iteration over unions /// and intersections of sparse iterations spaces. @@ -1221,45 +1307,23 @@ genTensorStore(merger, codegen, rewriter, op, rhs); return; } - assert(codegen.curVecLength == 1); // Construct iteration lattices for current loop index, with L0 at top. - // Then emit initialization code for the loop sequence at this level. - // We maintain the universal dense index if dense indices are still - // in play for a non-singleton loop sequence. - Location loc = op.getLoc(); unsigned idx = topSort[at]; - unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx)); - unsigned lsize = merger.set(lts).size(); - assert(lsize != 0); - unsigned l0 = merger.set(lts)[0]; unsigned ldx = at == 0 ? -1u : topSort[at - 1]; - genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true); - bool needsUniv = false; - if (genInit(merger, codegen, rewriter, op, topSort, at, - merger.lat(l0).bits)) { - // Maintain the universal index only if it is actually - // consumed by a subsequent lattice point. - for (unsigned i = 1; i < lsize; i++) { - unsigned li = merger.set(lts)[i]; - if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) { - needsUniv = true; - break; - } - } - } + unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx)); + + // Start a loop sequence. + bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at, + ldx, lts); - // Emit a loop for every lattice point L0 >= Li. + // Emit a loop for every lattice point L0 >= Li in this loop sequence. + unsigned lsize = merger.set(lts).size(); for (unsigned i = 0; i < lsize; i++) { + // Start a loop.
unsigned li = merger.set(lts)[i]; - - // Emit loop. - codegen.curVecLength = 1; - llvm::BitVector indices = merger.lat(li).simple; Operation *loop = - genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices); - genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, - merger.lat(li).bits); + startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv); // Visit all lattices points with Li >= Lj to generate the // loop-body, possibly with if statements for coiteration. @@ -1280,27 +1344,14 @@ } } - // Wrap-up induction and restore insertion point. - if (isWhile) { - scf::WhileOp whileOp = cast(loop); - rewriter.setInsertionPointToEnd(&whileOp.after().front()); - genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv, - merger.lat(li).bits, whileOp.results()); - } else { - needsUniv = false; - if (codegen.redVal) { - rewriter.create(loc, codegen.redVal); - codegen.redVal = loop->getResult(0); - } - } + // End a loop. + needsUniv = + endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv); rewriter.setInsertionPointAfter(loop); } - // Wrap-up loop sequence. - codegen.curVecLength = 1; - genReductionEnd(merger, codegen, rewriter, op); - genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false); - codegen.loops[idx] = Value(); + // End a loop sequence. + endLoopSeq(merger, codegen, rewriter, op, exp, idx, ldx); } /// Converts the result computed by the sparse kernel into the required form. @@ -1385,8 +1436,7 @@ // Recursively generates code. CodeGen codegen(options, numTensors, numLoops); - if (!genBuffers(merger, codegen, rewriter, op)) - return failure(); // could not bufferize + genBuffers(merger, codegen, rewriter, op); genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0); genResult(merger, codegen, rewriter, op); return success();