diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -247,7 +247,7 @@ /// Constructs a new invariant expression, and returns its identifier. ExprId addInvariantExp(Value v); /// Constructs a new unary or binary expression, and returns its identifier. - ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1 = detail::kInvalidId, + ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1 = ExprId(), Operation *op = nullptr); /// Constructs a new sesquinary expression, and returns its identifier. /// Currently no sesquinary `Kind` allows specifying the `op`, but we @@ -497,15 +497,15 @@ /// dangling-reference problems if the loop body inserts new sets. const TensorExp &exp(ExprId e) const { assert(isValidExprId(e)); - return tensorExps[e]; + return tensorExps[e.value]; } const LatPoint &lat(LatPointId p) const { assert(isValidLatPointId(p)); - return latPoints[p]; + return latPoints[p.value]; } ArrayRef<LatPointId> set(LatSetId s) const { assert(isValidLatSetId(s)); - return latSets[s]; + return latSets[s.value]; } /// Checks whether the given expression has an associated value. @@ -518,7 +518,7 @@ void setExprValue(ExprId e, Value v) { assert(isValidExprId(e)); assert(v && "Got an undefined value"); - auto &val = tensorExps[e].val; + auto &val = tensorExps[e.value].val; assert(!val && "Expression already has an associated value"); val = v; } @@ -529,7 +529,7 @@ /// then use `updateExprValue` instead. void clearExprValue(ExprId e) { assert(isValidExprId(e)); - auto &val = tensorExps[e].val; + auto &val = tensorExps[e.value].val; assert(val && "Expression does not have an associated value to clear"); val = Value(); } @@ -547,7 +547,7 @@ // provide better invariants. 
void updateExprValue(ExprId e, Value v) { assert(isValidExprId(e)); - tensorExps[e].val = v; + tensorExps[e.value].val = v; } #ifndef NDEBUG @@ -581,13 +581,13 @@ return isValidTensorId(t) && lvl < lvlToLoop[t].size(); } bool isValidExprId(ExprId e) const { - return e != detail::kInvalidId && e < tensorExps.size(); + return e.isValid() && e.value < tensorExps.size(); } bool isValidLatPointId(LatPointId p) const { - return p != detail::kInvalidId && p < latPoints.size(); + return p.isValid() && p.value < latPoints.size(); } bool isValidLatSetId(LatSetId s) const { - return s != detail::kInvalidId && s < latSets.size(); + return s.isValid() && s.value < latSets.size(); } bool maybeZero(ExprId e) const; bool isInvariant(ExprId e) const { diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h @@ -6,13 +6,23 @@ // //===----------------------------------------------------------------------===// // -// TODO: This header currently defines some typedefs to avoid confusion -// between several different things which are all represented as `unsigned`. -// Over the next few commits, these typedefs will be replaced with "newtypes" -// (i.e., data types which are zero-cost abstractions for wrapping some -// underlying type while ensuring that the compiler keeps the new type -// distinct from the old type), along with related classes for iterating -// over them, etc. +// This header defines a number of "newtypes" (i.e., data types which are +// zero-cost abstractions for wrapping some underlying type while ensuring +// that the compiler keeps the new type distinct from the old type), +// along with related classes for iterating over them, etc. 
+// +// To guarantee that we achieve the goal of being zero-cost abstractions, +// we assert `std::is_trivially_copyable` and `std::is_trivially_destructible` +// of all these newtypes. These predicates license the compiler to make +// several optimizations; some of which are explicitly documented by the +// C++ standard: +// <https://en.cppreference.com/w/cpp/named_req/TriviallyCopyable> +// +// However, some key optimizations aren't mentioned by the standard; e.g., +// that trivially-copyable enables passing-by-value, and the conjunction +// of trivially-copyable and trivially-destructible enables passing those +// values in registers rather than on the stack (cf., +// <https://itanium-cxx-abi.github.io/cxx-abi/abi.html#non-trivial>). // //===----------------------------------------------------------------------===// @@ -78,18 +88,73 @@ //===----------------------------------------------------------------------===// /// `TensorExp` identifiers. These are allocated by `Merger::addExp`, /// and serve as unique identifiers for the corresponding `TensorExp` object. -using ExprId = unsigned; +class ExprId final { + friend class Merger; + friend class TensorExp; + explicit constexpr ExprId(unsigned value) : value(value) {} + +public: + /// Constructs a new expression identifier with a dedicated + /// known-invalid value. + explicit constexpr ExprId() : value(detail::kInvalidId) {} + /// Checks whether the expression identifier has the dedicated + /// known-invalid value. + constexpr bool isValid() const { return value != detail::kInvalidId; } + constexpr bool operator==(ExprId rhs) const { return value == rhs.value; } + constexpr bool operator!=(ExprId rhs) const { return value != rhs.value; } + +private: + unsigned value; +}; +static_assert(std::is_trivially_copyable_v<ExprId> && + std::is_trivially_destructible_v<ExprId>); //===----------------------------------------------------------------------===// /// `LatPoint` identifiers. These are allocated by `Merger::addLat`, /// and serve as unique identifiers for the corresponding `LatPoint` object. 
-using LatPointId = unsigned; +class LatPointId final { + friend class Merger; + explicit constexpr LatPointId(unsigned value) : value(value) {} + +public: + /// Constructs a new lattice-point identifier with a dedicated + /// known-invalid value. + explicit constexpr LatPointId() : value(detail::kInvalidId) {} + /// Checks whether the lattice-point identifier has the dedicated + /// known-invalid value. + constexpr bool isValid() const { return value != detail::kInvalidId; } + constexpr bool operator==(LatPointId rhs) const { return value == rhs.value; } + constexpr bool operator!=(LatPointId rhs) const { return value != rhs.value; } + +private: + unsigned value; +}; +static_assert(std::is_trivially_copyable_v<LatPointId> && + std::is_trivially_destructible_v<LatPointId>); //===----------------------------------------------------------------------===// /// `LatSet` identifiers. These are allocated by `Merger::addSet` (and /// by other methods calling that one), and serve as unique identifiers /// for the corresponding `SmallVector<LatPointId>` object. -using LatSetId = unsigned; +class LatSetId final { + friend class Merger; + explicit constexpr LatSetId(unsigned value) : value(value) {} + +public: + /// Constructs a new lattice-set identifier with a dedicated + /// known-invalid value. + explicit constexpr LatSetId() : value(detail::kInvalidId) {} + /// Checks whether the lattice-set identifier has the dedicated + /// known-invalid value. 
+ constexpr bool isValid() const { return value != detail::kInvalidId; } + constexpr bool operator==(LatSetId rhs) const { return value == rhs.value; } + constexpr bool operator!=(LatSetId rhs) const { return value != rhs.value; } + +private: + unsigned value; +}; +static_assert(std::is_trivially_copyable_v<LatSetId> && + std::is_trivially_destructible_v<LatSetId>); } // namespace sparse_tensor } // namespace mlir diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -143,7 +143,7 @@ // void startReduc(ExprId exp, Value val); - bool isReduc() const { return redExp != detail::kInvalidId; } + bool isReduc() const { return redExp.isValid(); } void updateReduc(Value val); Value getReduc() const { return redVal; } Value endReduc(); @@ -152,7 +152,7 @@ Value getValidLexInsert() const { return redValidLexInsert; } void startCustomReduc(ExprId exp); - bool isCustomReduc() const { return redCustom != detail::kInvalidId; } + bool isCustomReduc() const { return redCustom.isValid(); } Value getCustomRedId(); void endCustomReduc(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -56,8 +56,7 @@ latticeMerger(numTensors, numLoops, numFilterLoops, maxRank), loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(), - redExp(detail::kInvalidId), redCustom(detail::kInvalidId), - redValidLexInsert() {} + redExp(), redCustom(), redValidLexInsert() {} LogicalResult CodegenEnv::initTensorExp() { // Builds the tensor expression for the Linalg operation in SSA form. 
@@ -277,7 +276,7 @@ //===----------------------------------------------------------------------===// void CodegenEnv::startReduc(ExprId exp, Value val) { - assert(!isReduc() && exp != detail::kInvalidId); + assert(!isReduc() && exp.isValid()); redExp = exp; updateReduc(val); } @@ -296,7 +295,7 @@ Value val = redVal; redVal = val; latticeMerger.clearExprValue(redExp); - redExp = detail::kInvalidId; + redExp = ExprId(); return val; } @@ -311,7 +310,7 @@ } void CodegenEnv::startCustomReduc(ExprId exp) { - assert(!isCustomReduc() && exp != detail::kInvalidId); + assert(!isCustomReduc() && exp.isValid()); redCustom = exp; } @@ -322,5 +321,5 @@ void CodegenEnv::endCustomReduc() { assert(isCustomReduc()); - redCustom = detail::kInvalidId; + redCustom = ExprId(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1122,7 +1122,7 @@ linalg::GenericOp op = env.op(); Location loc = op.getLoc(); - if (e == ::mlir::sparse_tensor::detail::kInvalidId) + if (!e.isValid()) return Value(); const TensorExp &exp = env.exp(e); const auto kind = exp.kind; @@ -1157,7 +1157,7 @@ /// Hoists loop invariant tensor loads for which indices have been exhausted. static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, LoopId ldx, bool atStart) { - if (exp == ::mlir::sparse_tensor::detail::kInvalidId) + if (!exp.isValid()) return; if (env.exp(exp).kind == TensorExp::Kind::kTensor) { // Inspect tensor indices. diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -104,14 +104,14 @@ switch (kind) { // Leaf. 
case TensorExp::Kind::kTensor: - assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); + assert(x != detail::kInvalidId && !y.isValid() && !v && !o); tensor = x; return; case TensorExp::Kind::kInvariant: - assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o); + assert(x == detail::kInvalidId && !y.isValid() && v && !o); return; case TensorExp::Kind::kLoopVar: - assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); + assert(x != detail::kInvalidId && !y.isValid() && !v && !o); loop = x; return; // Unary operations. @@ -135,8 +135,8 @@ case TensorExp::Kind::kNegI: case TensorExp::Kind::kCIm: case TensorExp::Kind::kCRe: - assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); - children.e0 = x; + assert(x != detail::kInvalidId && !y.isValid() && !v && !o); + children.e0 = ExprId{x}; children.e1 = y; return; case TensorExp::Kind::kTruncF: @@ -150,21 +150,21 @@ case TensorExp::Kind::kCastIdx: case TensorExp::Kind::kTruncI: case TensorExp::Kind::kBitCast: - assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o); - children.e0 = x; + assert(x != detail::kInvalidId && !y.isValid() && v && !o); + children.e0 = ExprId{x}; children.e1 = y; return; case TensorExp::Kind::kBinaryBranch: case TensorExp::Kind::kSelect: - assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o); - children.e0 = x; + assert(x != detail::kInvalidId && !y.isValid() && !v && o); + children.e0 = ExprId{x}; children.e1 = y; return; case TensorExp::Kind::kUnary: // No assertion on y can be made, as the branching paths involve both // a unary (`mapSet`) and binary (`disjSet`) pathway. assert(x != detail::kInvalidId && !v && o); - children.e0 = x; + children.e0 = ExprId{x}; children.e1 = y; return; // Binary operations. 
@@ -187,14 +187,14 @@ case TensorExp::Kind::kShrS: case TensorExp::Kind::kShrU: case TensorExp::Kind::kShlI: - assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o); - children.e0 = x; + assert(x != detail::kInvalidId && y.isValid() && !v && !o); + children.e0 = ExprId{x}; children.e1 = y; return; case TensorExp::Kind::kBinary: case TensorExp::Kind::kReduce: - assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o); - children.e0 = x; + assert(x != detail::kInvalidId && y.isValid() && !v && o); + children.e0 = ExprId{x}; children.e1 = y; return; } @@ -226,37 +226,37 @@ ExprId Merger::addTensorExp(TensorId t) { assert(isValidTensorId(t)); const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId, - Value(), nullptr); + tensorExps.emplace_back(TensorExp::Kind::kTensor, t, ExprId(), Value(), + nullptr); return eNew; } ExprId Merger::addLoopVarExp(LoopId i) { assert(isValidLoopId(i)); const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId, - Value(), nullptr); + tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, ExprId(), Value(), + nullptr); return eNew; } ExprId Merger::addInvariantExp(Value v) { const ExprId eNew(tensorExps.size()); tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId, - detail::kInvalidId, v, nullptr); + ExprId(), v, nullptr); return eNew; } ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op) { assert(k > TensorExp::Kind::kLoopVar); const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(k, e0, e1, Value(), op); + tensorExps.emplace_back(k, e0.value, e1, Value(), op); return eNew; } ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op) { assert(k > TensorExp::Kind::kLoopVar); const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(k, e, detail::kInvalidId, v, op); + tensorExps.emplace_back(k, e.value, ExprId(), v, op); 
return eNew; } @@ -265,7 +265,7 @@ const unsigned size = numLoops * numTensors; const TensorLoopId b = makeTensorLoopId(t, i); latPoints.emplace_back(size, e); - latPoints[pNew].bits.set(b); + latPoints[pNew.value].bits.set(b); return pNew; } @@ -297,7 +297,7 @@ LatSetId Merger::conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1, Operation *op) { const LatSetId sNew = addSet(); - auto &setNew = latSets[sNew]; + auto &setNew = latSets[sNew.value]; for (const LatPointId p0 : set(s0)) for (const LatPointId p1 : set(s1)) setNew.push_back(conjLat(kind, p0, p1, op)); @@ -308,7 +308,7 @@ Operation *op) { const LatSetId sNew = conjSet(kind, s0, s1, op); // Followed by all in s0. - latSets[sNew].append(latSets[s0]); + latSets[sNew.value].append(latSets[s0.value]); // Map binary 0-y to unary -y. // TODO: move this if-else logic into buildLattices if (kind == TensorExp::Kind::kSubF) @@ -318,7 +318,7 @@ else if (kind == TensorExp::Kind::kSubI) s1 = mapSet(TensorExp::Kind::kNegI, s1); // Followed by all in s1. - latSets[sNew].append(latSets[s1]); + latSets[sNew.value].append(latSets[s1.value]); return sNew; } @@ -332,13 +332,13 @@ if (includeLeft) { if (opleft) s0 = mapSet(ltrans, s0, Value(), opleft); - latSets[sNew].append(latSets[s0]); + latSets[sNew.value].append(latSets[s0.value]); } // Right Region. 
if (includeRight) { if (opright) s1 = mapSet(rtrans, s1, Value(), opright); - latSets[sNew].append(latSets[s1]); + latSets[sNew.value].append(latSets[s1.value]); } return sNew; } @@ -347,9 +347,9 @@ Operation *op) { assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect); const LatSetId sNew = addSet(); - auto &setNew = latSets[sNew]; + auto &setNew = latSets[sNew.value]; for (const LatPointId p : set(s0)) { - const auto &point = latPoints[p]; + const auto &point = latPoints[p.value]; setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op))); } return sNew; @@ -357,7 +357,7 @@ LatSetId Merger::optimizeSet(LatSetId s0) { const LatSetId sNew = addSet(); - auto &setNew = latSets[sNew]; + auto &setNew = latSets[sNew.value]; const auto &set0 = set(s0); assert(!set0.empty()); const LatPointId p0 = set0[0]; @@ -365,7 +365,7 @@ bool add = true; if (p0 != p1) { // Check whether this is a straightforward copy. - if (expIsTensor(latPoints[p1].exp, outTensor)) + if (expIsTensor(latPoints[p1.value].exp, outTensor)) continue; // Check whether this conjunction is already covered. for (const LatPointId p2 : setNew) { @@ -381,7 +381,7 @@ setNew.push_back(p1); } for (const LatPointId p : setNew) - latPoints[p].simple = simplifyCond(sNew, p); + latPoints[p.value].simple = simplifyCond(sNew, p); return sNew; } @@ -861,7 +861,7 @@ if (hasSparseOut && t == outTensor) t = syntheticTensor; } - latSets[s].push_back(addLat(t, i, e)); + latSets[s.value].push_back(addLat(t, i, e)); return s; } // Unary operations.