diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -282,6 +282,7 @@
 
   /// Constructs a new iteration lattice point, and returns its identifier.
   LatPointId addLat(TensorId t, LoopId i, ExprId e);
+  LatPointId addLat(const BitVector &bits, ExprId e);
 
   /// Constructs a new (initially empty) set, and returns its identifier.
   LatSetId addSet();
@@ -449,12 +450,58 @@
   void setHasSparseOut(bool s) { hasSparseOut = s; }
 
   /// Convenience getters to immediately access the stored nodes.
-  /// Typically it is inadvisible to keep the reference around, as in
-  /// `TensorExpr &te = merger.exp(e)`, since insertions into the merger
-  /// may cause data movement and invalidate the underlying memory address.
-  TensorExp &exp(ExprId e) { return tensorExps[e]; }
-  LatPoint &lat(LatPointId p) { return latPoints[p]; }
-  SmallVector<LatPointId> &set(LatSetId s) { return latSets[s]; }
+  /// These methods return `const&` because the underlying objects must
+  /// not be mutated by client code. The only exception is for mutating
+  /// the value associated with an expression, for which there are
+  /// dedicated methods below.
+  ///
+  /// NOTE: It is inadvisable to keep the reference alive for a long
+  /// time (e.g., as in `const TensorExp &te = merger.exp(e)`), since
+  /// insertions into the merger can cause data movement which will
+  /// invalidate the underlying memory address. This isn't just a problem
+  /// with the `&` references, but also applies to the `ArrayRef`. In
+  /// particular, using `for (LatPointId p : merger.set(s))` will run into
+  /// the same dangling-reference problems if the loop body does
+  /// insertions into the merger.
+  const TensorExp &exp(ExprId e) const { return tensorExps[e]; }
+  const LatPoint &lat(LatPointId p) const { return latPoints[p]; }
+  // FIXME: Once we return a new `LatSetId` to client code, we never
+  // mutate the set again. So is there some other LLVM datatype which
+  // would be stable for use in foreach loops? (Or a better way to
+  // store the `latSets` field which would enable stability?)
+  ArrayRef<LatPointId> set(LatSetId s) const { return latSets[s]; }
+
+  /// Checks whether the given expression has an associated value.
+  bool hasExprValue(ExprId e) const {
+    return static_cast<bool>(tensorExps[e].val);
+  }
+
+  /// Sets the expression to have the associated value. Asserts that
+  /// the new value is defined, and that the expression does not already
+  /// have a value. If you want to overwrite a previous associated value,
+  /// use `updateExprValue` instead.
+  void setExprValue(ExprId e, Value v) {
+    assert(v && "Got an undefined value");
+    auto &val = tensorExps[e].val;
+    assert(!val && "Expression already has an associated value");
+    val = v;
+  }
+
+  /// Clears the value associated with the expression. Asserts that the
+  /// expression does indeed have an associated value before clearing it.
+  /// If you don't want to check for a previous associated value first,
+  /// then use `updateExprValue` instead.
+  void clearExprValue(ExprId e) {
+    auto &val = tensorExps[e].val;
+    assert(val && "Expression does not have an associated value to clear");
+    val = Value();
+  }
+
+  /// Unilaterally updates the expression to have the associated value.
+  /// That is, unlike `setExprValue` and `clearExprValue`, this method
+  /// does not perform any checks on whether the expression had a
+  /// previously associated value nor whether the new value is defined.
+  void updateExprValue(ExprId e, Value v) { tensorExps[e].val = v; }
 
 #ifndef NDEBUG
   /// Print methods (for debugging).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -65,9 +65,9 @@
   // Merger delegates.
   //
 
-  TensorExp &exp(ExprId e) { return latticeMerger.exp(e); }
-  LatPoint &lat(LatPointId l) { return latticeMerger.lat(l); }
-  SmallVector<LatPointId> &set(LatSetId s) { return latticeMerger.set(s); }
+  const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); }
+  const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); }
+  ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); }
   DimLevelType dlt(TensorId t, LoopId i) const {
     return latticeMerger.getDimLevelType(t, i);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -86,6 +86,9 @@
   auto r = callback(params); // may update parameters
   unsigned i = 0;
   if (isReduc()) {
+    // FIXME: This requires `updateExprValue` to perform updates without
+    // checking for a previous value; but it's not clear whether that's
+    // by design or might be a potential source for bugs.
     updateReduc(params[i++]);
     if (redValidLexInsert)
       setValidLexInsert(params[i++]);
@@ -237,12 +240,18 @@
 
 void CodegenEnv::updateReduc(Value val) {
   assert(isReduc());
-  redVal = exp(redExp).val = val;
+  redVal = val;
+  // NOTE: `genLoopBoundary` requires that this performs a unilateral
+  // update without checking for a previous value first. (It's not
+  // clear whether any other callsites also require that.)
+  latticeMerger.updateExprValue(redExp, val);
 }
 
 Value CodegenEnv::endReduc() {
+  assert(isReduc());
   Value val = redVal;
-  updateReduc(Value());
+  redVal = Value();
+  latticeMerger.clearExprValue(redExp);
   redExp = kInvalidId;
   return val;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -864,7 +864,7 @@
     assert(env.exp(exp).val);
     Value v0 = env.exp(exp).val;
     genInsertionStore(env, builder, t, v0);
-    env.exp(exp).val = Value();
+    env.merger().clearExprValue(exp);
     // Yield modified insertion chain along true branch.
     Value mchain = env.getInsertionChain();
     builder.create<scf::YieldOp>(op.getLoc(), mchain);
@@ -943,10 +943,8 @@
   if (kind == Kind::kReduce)
     env.endCustomReduc(); // exit custom
 
-  if (kind == kSelect) {
-    assert(!exp.val);
-    env.exp(e).val = v0; // Preserve value for later use.
-  }
+  if (kind == kSelect)
+    env.merger().setExprValue(e, v0); // Preserve value for later use.
 
   return ee;
 }
@@ -998,7 +996,10 @@
       }
     } else {
       // Start or end loop invariant hoisting of a tensor load.
-      env.exp(exp).val = atStart ? genTensorLoad(env, builder, exp) : Value();
+      if (atStart)
+        env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
+      else
+        env.merger().clearExprValue(exp);
     }
   } else if (env.exp(exp).kind != Kind::kInvariant &&
              env.exp(exp).kind != Kind::kLoopVar) {
@@ -1152,8 +1153,7 @@
 
 /// Generates the induction structure for a while-loop.
 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
-                            bool needsUniv, BitVector &induction,
-                            scf::WhileOp whileOp) {
+                            bool needsUniv, scf::WhileOp whileOp) {
   Location loc = env.op().getLoc();
   // Finalize each else branch of all if statements.
   if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
@@ -1192,7 +1192,7 @@
 
 /// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx, - BitVector &conditions) { + const BitVector &conditions) { Location loc = env.op().getLoc(); SmallVector types; Value cond; @@ -1291,12 +1291,9 @@ // Maintain the universal index only if it is actually // consumed by a subsequent lattice point. if (needsUniv) { - unsigned lsize = env.set(lts).size(); - for (unsigned i = 1; i < lsize; i++) { - const LatPointId li = env.set(lts)[i]; + for (const LatPointId li : env.set(lts).drop_front()) if (!env.merger().hasAnySparse(env.lat(li).simple)) return true; - } } return false; } @@ -1471,7 +1468,7 @@ LoopId idx, LatPointId li, bool needsUniv) { // End a while-loop. if (auto whileOp = dyn_cast(loop)) { - finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp); + finalizeWhileOp(env, rewriter, idx, needsUniv, whileOp); } else if (auto forOp = dyn_cast(loop)) { // Any iteration of a reduction for-loop creates a valid lex insert. if (env.isReduc() && env.getValidLexInsert()) @@ -1522,10 +1519,14 @@ bool needsUniv = startLoopSeq(env, rewriter, exp, at, idx, ldx, lts); // Emit a loop for every lattice point L0 >= Li in this loop sequence. - unsigned lsize = env.set(lts).size(); + // + // NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))` + // because the loop body causes data-movement which invalidates + // the iterator. + const unsigned lsize = env.set(lts).size(); for (unsigned i = 0; i < lsize; i++) { - // Start a loop. const LatPointId li = env.set(lts)[i]; + // Start a loop. 
Operation *loop = startLoop(env, rewriter, at, li, needsUniv); // Visit all lattices points with Li >= Lj to generate the @@ -1533,7 +1534,10 @@ Value redInput = env.getReduc(); Value cntInput = env.getExpandCount(); Value insInput = env.getInsertionChain(); - bool isWhile = dyn_cast(loop) != nullptr; + const bool isWhile = isa(loop); + // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))` + // because the loop body causes data-movement which invalidates the + // iterator. for (unsigned j = 0; j < lsize; j++) { const LatPointId lj = env.set(lts)[j]; const ExprId ej = env.lat(lj).exp; diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -240,6 +240,13 @@ return p; } +LatPointId Merger::addLat(const BitVector &bits, ExprId e) { + assert(bits.size() == numLoops * numTensors); + const LatPointId p = latPoints.size(); + latPoints.emplace_back(bits, e); + return p; +} + LatSetId Merger::addSet() { const LatSetId s = latSets.size(); latSets.emplace_back(); @@ -310,8 +317,7 @@ const LatSetId s = addSet(); for (const LatPointId p : latSets[s0]) { const ExprId e = addExp(kind, latPoints[p].exp, v, op); - latPoints.emplace_back(latPoints[p].bits, e); - latSets[s].push_back(latPoints.size() - 1); + latSets[s].push_back(addLat(latPoints[p].bits, e)); } return s; }