diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -76,20 +76,22 @@
 /// for the corresponding `SmallVector` object.
 using LatSetId = unsigned;
 
+namespace detail {
 /// A constant serving as the canonically invalid identifier, regardless
 /// of the identifier type.
 static constexpr unsigned kInvalidId = -1u;
+} // namespace detail
 
-/// Children subexpressions of tensor operations.
-struct Children {
-  ExprId e0;
-  ExprId e1;
-};
-
-/// Tensor expression. Represents a MLIR expression in tensor index notation.
-struct TensorExp {
+/// Tensor expression. Represents an MLIR expression in tensor index notation.
+struct TensorExp final {
   enum class Kind;
 
+  /// Child subexpressions for non-leaf expressions.
+  struct Children final {
+    ExprId e0;
+    ExprId e1;
+  };
+
   // The `x` parameter has different types depending on the value of the
   // `k` parameter. The correspondences are:
   // * `kTensor` -> `TensorId`
@@ -209,7 +211,7 @@
 /// of `TensorLoopId`s, together with the identifier of the corresponding
 /// tensor expression. The formal conjunction is represented as a set of
 /// `TensorLoopId`, where that set is implemented as a `BitVector`.
-struct LatPoint {
+struct LatPoint final {
   /// Construct the lattice point from a given set of `TensorLoopId`s.
   LatPoint(const BitVector &bits, ExprId e);
 
@@ -269,17 +271,18 @@
   /// Constructs a new tensor expression, and returns its identifier.
   /// The type of the `e0` argument varies according to the value of the
   /// `k` argument, as described by the `TensorExp` ctor.
-  ExprId addExp(TensorExp::Kind k, unsigned e0, ExprId e1 = kInvalidId,
+  ExprId addExp(TensorExp::Kind k, unsigned e0, ExprId e1 = detail::kInvalidId,
                 Value v = Value(), Operation *op = nullptr);
   ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr) {
-    return addExp(k, e, kInvalidId, v, op);
+    return addExp(k, e, detail::kInvalidId, v, op);
   }
   ExprId addExp(TensorExp::Kind k, Value v, Operation *op = nullptr) {
-    return addExp(k, kInvalidId, kInvalidId, v, op);
+    return addExp(k, detail::kInvalidId, detail::kInvalidId, v, op);
   }
 
   /// Constructs a new iteration lattice point, and returns its identifier.
   LatPointId addLat(TensorId t, LoopId i, ExprId e);
+  LatPointId addLat(const BitVector &bits, ExprId e);
 
   /// Constructs a new (initially empty) set, and returns its identifier.
   LatSetId addSet();
@@ -494,12 +497,60 @@
   }
 
   /// Convenience getters to immediately access the stored nodes.
-  /// Typically it is inadvisible to keep the reference around, as in
-  /// `TensorExpr &te = merger.exp(e)`, since insertions into the merger
-  /// may cause data movement and invalidate the underlying memory address.
-  TensorExp &exp(ExprId e) { return tensorExps[e]; }
-  LatPoint &lat(LatPointId p) { return latPoints[p]; }
-  SmallVector<LatPointId> &set(LatSetId s) { return latSets[s]; }
+  /// These methods return `const&` because the underlying objects must
+  /// not be mutated by client code. The only exception is for mutating
+  /// the value associated with an expression, for which there are
+  /// dedicated methods below.
+  ///
+  /// NOTE: It is inadvisable to keep the reference alive for a long
+  /// time (e.g., as in `TensorExpr &te = merger.exp(e)`), since insertions
+  /// into the merger can cause data movement which will invalidate the
+  /// underlying memory address. This isn't just a problem with the `&`
+  /// references, but also applies to the `ArrayRef`. In particular,
+  /// using `for (LatPointId p : merger.set(s))` will run into the same
+  /// dangling-reference problems if the loop body inserts new sets.
+  const TensorExp &exp(ExprId e) const { return tensorExps[e]; }
+  const LatPoint &lat(LatPointId p) const { return latPoints[p]; }
+  ArrayRef<LatPointId> set(LatSetId s) const { return latSets[s]; }
+
+  /// Checks whether the given expression has an associated value.
+  bool hasExprValue(ExprId e) const {
+    return static_cast<bool>(tensorExps[e].val);
+  }
+
+  /// Sets the expression to have the associated value. Asserts that
+  /// the new value is defined, and that the expression does not already
+  /// have a value. If you want to overwrite a previous associated value,
+  /// use `updateExprValue` instead.
+  void setExprValue(ExprId e, Value v) {
+    assert(v && "Got an undefined value");
+    auto &val = tensorExps[e].val;
+    assert(!val && "Expression already has an associated value");
+    val = v;
+  }
+
+  /// Clears the value associated with the expression. Asserts that the
+  /// expression does indeed have an associated value before clearing it.
+  /// If you don't want to check for a previous associated value first,
+  /// then use `updateExprValue` instead.
+  void clearExprValue(ExprId e) {
+    auto &val = tensorExps[e].val;
+    assert(val && "Expression does not have an associated value to clear");
+    val = Value();
+  }
+
+  /// Unilaterally updates the expression to have the associated value.
+  /// That is, unlike `setExprValue` and `clearExprValue`, this method
+  /// does not perform any checks on whether the expression had a
+  /// previously associated value nor whether the new value is defined.
+  //
+  // TODO: The unilateral update semantics are required by the
+  // current implementation of `CodegenEnv::genLoopBoundary`; however,
+  // that implementation seems a bit dubious. We would much rather have
+  // the semantics `{ clearExprValue(e); setExprValue(e, v); }` or
+  // `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those
+  // provide better invariants.
+  void updateExprValue(ExprId e, Value v) { tensorExps[e].val = v; }
 
 #ifndef NDEBUG
   /// Print methods (for debugging).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -65,9 +65,9 @@
   // Merger delegates.
   //
 
-  TensorExp &exp(ExprId e) { return latticeMerger.exp(e); }
-  LatPoint &lat(LatPointId l) { return latticeMerger.lat(l); }
-  SmallVector<LatPointId> &set(LatSetId s) { return latticeMerger.set(s); }
+  const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); }
+  const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); }
+  ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); }
   DimLevelType dlt(TensorId t, LoopId i) const {
     return latticeMerger.getDimLevelType(t, i);
   }
@@ -133,7 +133,7 @@
   //
 
   void startReduc(ExprId exp, Value val);
-  bool isReduc() const { return redExp != kInvalidId; }
+  bool isReduc() const { return redExp != detail::kInvalidId; }
   void updateReduc(Value val);
   Value getReduc() const { return redVal; }
   Value endReduc();
@@ -142,7 +142,7 @@
   Value getValidLexInsert() const { return redValidLexInsert; }
 
   void startCustomReduc(ExprId exp);
-  bool isCustomReduc() const { return redCustom != kInvalidId; }
+  bool isCustomReduc() const { return redCustom != detail::kInvalidId; }
   Value getCustomRedId();
   void endCustomReduc();
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -55,8 +55,8 @@
     : linalgOp(linop), sparseOptions(opts),
       latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(),
       topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
-      expFilled(), expAdded(), expCount(), redVal(), redExp(kInvalidId),
-      redCustom(kInvalidId), redValidLexInsert() {}
+      expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
+      redCustom(detail::kInvalidId), redValidLexInsert() {}
 
 LogicalResult CodegenEnv::initTensorExp() {
   // Builds the tensor expression for the Linalg operation in SSA form.
@@ -130,6 +130,9 @@
   auto r = callback(params); // may update parameters
   unsigned i = 0;
   if (isReduc()) {
+    // FIXME: This requires `updateExprValue` to perform updates without
+    // checking for a previous value; but it's not clear whether that's
+    // by design or might be a potential source for bugs.
     updateReduc(params[i++]);
     if (redValidLexInsert)
       setValidLexInsert(params[i++]);
@@ -274,20 +277,26 @@
 //===----------------------------------------------------------------------===//
 
 void CodegenEnv::startReduc(ExprId exp, Value val) {
-  assert(!isReduc() && exp != kInvalidId);
+  assert(!isReduc() && exp != detail::kInvalidId);
   redExp = exp;
   updateReduc(val);
 }
 
 void CodegenEnv::updateReduc(Value val) {
   assert(isReduc());
-  redVal = exp(redExp).val = val;
+  redVal = val;
+  // NOTE: `genLoopBoundary` requires that this performs a unilateral
+  // update without checking for a previous value first. (It's not
+  // clear whether any other callsites also require that.)
+  latticeMerger.updateExprValue(redExp, val);
 }
 
 Value CodegenEnv::endReduc() {
+  assert(isReduc());
   Value val = redVal;
-  updateReduc(Value());
-  redExp = kInvalidId;
+  redVal = Value();
+  latticeMerger.clearExprValue(redExp);
+  redExp = detail::kInvalidId;
   return val;
 }
 
@@ -302,7 +311,7 @@
 }
 
 void CodegenEnv::startCustomReduc(ExprId exp) {
-  assert(!isCustomReduc() && exp != kInvalidId);
+  assert(!isCustomReduc() && exp != detail::kInvalidId);
   redCustom = exp;
 }
 
@@ -313,5 +322,5 @@
 
 void CodegenEnv::endCustomReduc() {
   assert(isCustomReduc());
-  redCustom = kInvalidId;
+  redCustom = detail::kInvalidId;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -265,6 +265,10 @@
     return isOutputTensor(tid) && isSparseOut;
   }
 
+  bool isValidLevel(TensorId tid, Level lvl) const {
+    return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
+  }
+
   /// Prepares loop for iterating over `tensor[lvl]`, under the assumption
   /// that `tensor[0...lvl-1]` loops have already been set up.
   void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -450,7 +450,7 @@
   for (auto [t, l] : llvm::zip(tids, lvls)) {
     // TODO: this check for validity of the (t,l) pairs should be
     // checked/enforced at the callsites, if possible.
-    assert(t < lvlTypes.size() && l < lvlTypes[t].size());
+    assert(isValidLevel(t, l));
     assert(!coords[t][l]); // We cannot re-enter the same level
     const auto lvlTp = lvlTypes[t][l];
     const bool isSparse = isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp);
@@ -566,7 +566,7 @@
 Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
     OpBuilder &builder, Location loc, TensorId tid, Level lvl,
     AffineExpr affine, MutableArrayRef<Value> reduc) {
-  assert(tid < lvlTypes.size() && lvl < lvlTypes[tid].size());
+  assert(isValidLevel(tid, lvl));
   assert(!affine.isa<AffineDimExpr>() && !isDenseDLT(lvlTypes[tid][lvl]));
   // We can not re-enter the same level.
   assert(!coords[tid][lvl]);
@@ -856,7 +856,7 @@
 
 void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                              TensorId tid, Level dstLvl) {
-  assert(tid < lvlTypes.size() && dstLvl < lvlTypes[tid].size());
+  assert(isValidLevel(tid, dstLvl));
   const auto lvlTp = lvlTypes[tid][dstLvl];
 
   if (isDenseDLT(lvlTp))
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1057,7 +1057,7 @@
       assert(env.exp(exp).val);
       Value v0 = env.exp(exp).val;
       genInsertionStore(env, builder, t, v0);
-      env.exp(exp).val = Value();
+      env.merger().clearExprValue(exp);
       // Yield modified insertion chain along true branch.
       Value mchain = env.getInsertionChain();
       builder.create<scf::YieldOp>(op.getLoc(), mchain);
@@ -1111,7 +1111,7 @@
   linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
 
-  if (e == kInvalidId)
+  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
     return Value();
   const TensorExp &exp = env.exp(e);
   const auto kind = exp.kind;
@@ -1137,10 +1137,8 @@
   if (kind == TensorExp::Kind::kReduce)
     env.endCustomReduc(); // exit custom
 
-  if (kind == TensorExp::Kind::kSelect) {
-    assert(!exp.val);
-    env.exp(e).val = v0; // Preserve value for later use.
-  }
+  if (kind == TensorExp::Kind::kSelect)
+    env.merger().setExprValue(e, v0); // Preserve value for later use.
 
   return ee;
 }
@@ -1148,11 +1146,11 @@
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                           LoopId ldx, bool atStart) {
-  if (exp == kInvalidId)
+  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
     return;
   if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
     // Inspect tensor indices.
-    bool isAtLoop = ldx == kInvalidId;
+    bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId;
     linalg::GenericOp op = env.op();
     OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
     auto map = op.getMatchingIndexingMap(&t);
@@ -1192,7 +1190,10 @@
       }
     } else {
       // Start or end loop invariant hoisting of a tensor load.
-      env.exp(exp).val = atStart ? genTensorLoad(env, builder, exp) : Value();
+      if (atStart)
+        env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
+      else
+        env.merger().clearExprValue(exp);
     }
   } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
              env.exp(exp).kind != TensorExp::Kind::kLoopVar) {
@@ -1346,8 +1347,7 @@
 
 /// Generates the induction structure for a while-loop.
 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
-                            bool needsUniv, BitVector &induction,
-                            scf::WhileOp whileOp) {
+                            bool needsUniv, scf::WhileOp whileOp) {
   Location loc = env.op().getLoc();
   // Finalize each else branch of all if statements.
   if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
@@ -1386,7 +1386,7 @@
 
 /// Generates a single if-statement within a while-loop.
 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
-                       BitVector &conditions) {
+                       const BitVector &conditions) {
   Location loc = env.op().getLoc();
   SmallVector<Type> types;
   Value cond;
@@ -1486,13 +1486,10 @@
   // Maintain the universal index only if it is actually
   // consumed by a subsequent lattice point.
   if (needsUniv) {
-    unsigned lsize = env.set(lts).size();
-    for (unsigned i = 1; i < lsize; i++) {
-      const LatPointId li = env.set(lts)[i];
+    for (const LatPointId li : env.set(lts).drop_front())
       if (!env.merger().hasAnySparse(env.lat(li).simple) &&
           !env.merger().hasSparseIdxReduction(env.lat(li).simple))
         return true;
-    }
   }
   return false;
 }
@@ -1675,7 +1672,7 @@
                     LoopId idx, LatPointId li, bool needsUniv) {
   // End a while-loop.
   if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
-    finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp);
+    finalizeWhileOp(env, rewriter, idx, needsUniv, whileOp);
   } else if (auto forOp = dyn_cast<scf::ForOp>(loop)) {
     // Any iteration of a reduction for-loop creates a valid lex insert.
     if (env.isReduc() && env.getValidLexInsert())
@@ -1718,7 +1715,8 @@
 
   // Construct iteration lattices for current loop index, with L0 at top.
   const LoopId idx = env.topSortAt(at);
-  const LoopId ldx = at == 0 ? kInvalidId : env.topSortAt(at - 1);
+  const LoopId ldx = at == 0 ? ::mlir::sparse_tensor::detail::kInvalidId
+                             : env.topSortAt(at - 1);
   const LatSetId lts =
       env.merger().optimizeSet(env.merger().buildLattices(exp, idx));
 
@@ -1726,10 +1724,14 @@
   bool needsUniv = startLoopSeq(env, rewriter, exp, at, idx, ldx, lts);
 
   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
-  unsigned lsize = env.set(lts).size();
+  //
+  // NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))`
+  // because the loop body causes data-movement which invalidates
+  // the iterator.
+  const unsigned lsize = env.set(lts).size();
   for (unsigned i = 0; i < lsize; i++) {
-    // Start a loop.
     const LatPointId li = env.set(lts)[i];
+    // Start a loop.
     auto [loop, isSingleCond] = startLoop(env, rewriter, at, li, needsUniv);
 
     // Visit all lattices points with Li >= Lj to generate the
@@ -1737,6 +1739,9 @@
     Value redInput = env.getReduc();
     Value cntInput = env.getExpandCount();
     Value insInput = env.getInsertionChain();
+    // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
+    // because the loop body causes data-movement which invalidates the
+    // iterator.
     for (unsigned j = 0; j < lsize; j++) {
       const LatPointId lj = env.set(lts)[j];
       const ExprId ej = env.lat(lj).exp;
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -103,16 +103,16 @@
   switch (kind) {
   // Leaf.
   case TensorExp::Kind::kTensor:
-    assert(x != kInvalidId && y == kInvalidId && !v && !o);
+    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
     tensor = x;
-    break;
+    return;
   case TensorExp::Kind::kInvariant:
-    assert(x == kInvalidId && y == kInvalidId && v && !o);
-    break;
+    assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
+    return;
   case TensorExp::Kind::kLoopVar:
-    assert(x != kInvalidId && y == kInvalidId && !v && !o);
+    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
     loop = x;
-    break;
+    return;
   // Unary operations.
   case TensorExp::Kind::kAbsF:
   case TensorExp::Kind::kAbsC:
@@ -134,10 +134,10 @@
   case TensorExp::Kind::kNegI:
   case TensorExp::Kind::kCIm:
   case TensorExp::Kind::kCRe:
-    assert(x != kInvalidId && y == kInvalidId && !v && !o);
+    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
     children.e0 = x;
     children.e1 = y;
-    break;
+    return;
   case TensorExp::Kind::kTruncF:
   case TensorExp::Kind::kExtF:
   case TensorExp::Kind::kCastFS:
@@ -149,23 +149,23 @@
   case TensorExp::Kind::kCastIdx:
   case TensorExp::Kind::kTruncI:
   case TensorExp::Kind::kBitCast:
-    assert(x != kInvalidId && y == kInvalidId && v && !o);
+    assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
     children.e0 = x;
    children.e1 = y;
-    break;
+    return;
   case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
-    assert(x != kInvalidId && y == kInvalidId && !v && o);
+    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
     children.e0 = x;
     children.e1 = y;
-    break;
+    return;
   case TensorExp::Kind::kUnary:
     // No assertion on y can be made, as the branching paths involve both
     // a unary (`mapSet`) and binary (`disjSet`) pathway.
-    assert(x != kInvalidId && !v && o);
+    assert(x != detail::kInvalidId && !v && o);
     children.e0 = x;
     children.e1 = y;
-    break;
+    return;
   // Binary operations.
   case TensorExp::Kind::kMulF:
   case TensorExp::Kind::kMulC:
@@ -186,17 +186,18 @@
   case TensorExp::Kind::kShrS:
   case TensorExp::Kind::kShrU:
   case TensorExp::Kind::kShlI:
-    assert(x != kInvalidId && y != kInvalidId && !v && !o);
+    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
     children.e0 = x;
     children.e1 = y;
-    break;
+    return;
   case TensorExp::Kind::kBinary:
   case TensorExp::Kind::kReduce:
-    assert(x != kInvalidId && y != kInvalidId && !v && o);
+    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
     children.e0 = x;
     children.e1 = y;
-    break;
+    return;
   }
+  llvm_unreachable("unexpected kind");
 }
 
 LatPoint::LatPoint(const BitVector &bits, ExprId e) : bits(bits), exp(e) {}
@@ -247,6 +248,13 @@
   return p;
 }
 
+LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
+  assert(bits.size() == numLoops * numTensors);
+  const LatPointId p = latPoints.size();
+  latPoints.emplace_back(bits, e);
+  return p;
+}
+
 LatSetId Merger::addSet() {
   const LatSetId s = latSets.size();
   latSets.emplace_back();
@@ -322,8 +330,7 @@
   const LatSetId s = addSet();
   for (const LatPointId p : latSets[s0]) {
     const ExprId e = addExp(kind, latPoints[p].exp, v, op);
-    latPoints.emplace_back(latPoints[p].bits, e);
-    latSets[s].push_back(latPoints.size() - 1);
+    latSets[s].push_back(addLat(latPoints[p].bits, e));
   }
   return s;
 }
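
Note (not part of the patch): a minimal usage sketch of the value-mutation API introduced in Merger.h above, showing the intended discipline for client code; the function name `illustrateExprValueApi` is hypothetical, and it assumes a live `Merger` and a defined `Value` are in scope.

    // Reads go through the `const` accessors; copy out what you need rather
    // than holding a reference or ArrayRef across insertions into the merger.
    void illustrateExprValueApi(Merger &merger, ExprId e, Value v) {
      const TensorExp::Kind kind = merger.exp(e).kind; // safe: copied, not held
      (void)kind;
      if (!merger.hasExprValue(e))
        merger.setExprValue(e, v);  // first association; asserts no prior value
      merger.updateExprValue(e, v); // unchecked overwrite (genLoopBoundary-style)
      merger.clearExprValue(e);     // asserts a value is present, then clears it
    }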