diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -13,6 +13,8 @@
 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
 
+#include "mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h"
+
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -25,8 +27,8 @@
 
 // TODO: These type aliases currently only serve to make the code more
 // self-documenting, however because they are not type-checked they can
-// do nothing to prevent mixups. We should really change them from mere
-// aliases to actual struct definitions, so that we can type-check them.
+// do nothing to prevent mixups. They should be made into newtypes and
+// moved into MergerNewtypes.h with the other newtypes.
 
 /// Tensor identifiers. The valid set of identifiers is defined by the
 /// first argument passed to the `Merger` ctor.
@@ -63,25 +65,6 @@
 /// just the implementation for a set of `TensorLoopId` values).
 using TensorLoopId = unsigned;
 
-/// `TensorExp` identifiers. These are allocated by `Merger::addExp`,
-/// and serve as unique identifiers for the corresponding `TensorExp` object.
-using ExprId = unsigned;
-
-/// `LatPoint` identifiers. These are allocated by `Merger::addLat`,
-/// and serve as unique identifiers for the corresponding `LatPoint` object.
-using LatPointId = unsigned;
-
-/// `LatSet` identifiers. These are allocated by `Merger::addSet` (and
-/// by other methods calling that one), and serve as unique identifiers
-/// for the corresponding `SmallVector<LatPointId>` object.
-using LatSetId = unsigned;
-
-namespace detail {
-/// A constant serving as the canonically invalid identifier, regardless
-/// of the identifier type.
-static constexpr unsigned kInvalidId = -1u;
-} // namespace detail
-
 /// Tensor expression. Represents an MLIR expression in tensor index notation.
 struct TensorExp final {
   enum class Kind;
@@ -304,7 +287,7 @@
   /// Constructs a new invariant expression, and returns its identifier.
   ExprId addInvariantExp(Value v);
   /// Constructs a new unary or binary expression, and returns its identifier.
-  ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1 = detail::kInvalidId,
+  ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1 = ExprId(),
                 Operation *op = nullptr);
   /// Constructs a new sesquinary expression, and returns its identifier.
   /// Currently no sesquinary `Kind` allows specifying the `op`, but we
@@ -554,15 +537,15 @@
   /// dangling-reference problems if the loop body inserts new sets.
   const TensorExp &exp(ExprId e) const {
     assert(isValidExprId(e));
-    return tensorExps[e];
+    return tensorExps[e.value];
   }
   const LatPoint &lat(LatPointId p) const {
     assert(isValidLatPointId(p));
-    return latPoints[p];
+    return latPoints[p.value];
   }
   ArrayRef<LatPointId> set(LatSetId s) const {
     assert(isValidLatSetId(s));
-    return latSets[s];
+    return latSets[s.value];
   }
 
   /// Checks whether the given expression has an associated value.
@@ -575,7 +558,7 @@
   void setExprValue(ExprId e, Value v) {
     assert(isValidExprId(e));
     assert(v && "Got an undefined value");
-    auto &val = tensorExps[e].val;
+    auto &val = tensorExps[e.value].val;
     assert(!val && "Expression already has an associated value");
     val = v;
   }
@@ -586,7 +569,7 @@
   /// then use `updateExprValue` instead.
   void clearExprValue(ExprId e) {
     assert(isValidExprId(e));
-    auto &val = tensorExps[e].val;
+    auto &val = tensorExps[e.value].val;
     assert(val && "Expression does not have an associated value to clear");
     val = Value();
   }
@@ -604,7 +587,7 @@
   // provide better invariants.
void updateExprValue(ExprId e, Value v) { assert(isValidExprId(e)); - tensorExps[e].val = v; + tensorExps[e.value].val = v; } #ifndef NDEBUG @@ -638,13 +621,13 @@ return isValidTensorId(t) && lvl < lvlToLoop[t].size(); } bool isValidExprId(ExprId e) const { - return e != detail::kInvalidId && e < tensorExps.size(); + return e.isValid() && e.value < tensorExps.size(); } bool isValidLatPointId(LatPointId p) const { - return p != detail::kInvalidId && p < latPoints.size(); + return p.isValid() && p.value < latPoints.size(); } bool isValidLatSetId(LatSetId s) const { - return s != detail::kInvalidId && s < latSets.size(); + return s.isValid() && s.value < latSets.size(); } bool maybeZero(ExprId e) const; bool isInvariant(ExprId e) const { diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h @@ -0,0 +1,118 @@ +//===- MergerNewtypes.h - Newtypes for the `Merger` class -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines a number of "newtypes" (i.e., data types which are +// zero-cost abstractions for wrapping some underlying type while ensuring +// that the compiler keeps the new type distinct from the old type), +// along with related classes for iterating over them, etc. +// +// To guarantee that we achieve the goal of being zero-cost abstractions, +// we assert `std::is_trivially_copyable` and `std::is_trivially_destructible` +// of all these newtypes. 
These predicates license the compiler to make
+// several optimizations; some of which are explicitly documented by the
+// C++ standard:
+// <https://en.cppreference.com/w/cpp/language/classes#Trivially_copyable_class>
+//
+// However, some key optimizations aren't mentioned by the standard; e.g.,
+// that trivially-copyable enables passing-by-value, and the conjunction
+// of trivially-copyable and trivially-destructible enables passing those
+// values in registers rather than on the stack (cf.,
+// <https://www.agner.org/optimize/calling_conventions.pdf>).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGERNEWTYPES_H_
+#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGERNEWTYPES_H_
+
+#include <cassert>
+#include <type_traits>
+
+namespace mlir {
+namespace sparse_tensor {
+
+namespace detail {
+/// A constant serving as the canonically invalid identifier,
+/// regardless of the identifier type.
+static constexpr unsigned kInvalidId = -1u;
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+/// `TensorExp` identifiers. These are allocated by `Merger::addExp`,
+/// and serve as unique identifiers for the corresponding `TensorExp` object.
+class ExprId final {
+  friend class Merger;
+  friend struct TensorExp; // Must be "struct" for [-Wmismatched-tags].
+  explicit constexpr ExprId(unsigned value) : value(value) {}
+
+public:
+  /// Constructs a new expression identifier with a dedicated
+  /// known-invalid value.
+  explicit constexpr ExprId() : value(detail::kInvalidId) {}
+  /// Checks whether the expression identifier has the dedicated
+  /// known-invalid value.
+  constexpr bool isValid() const { return value != detail::kInvalidId; }
+  constexpr bool operator==(ExprId rhs) const { return value == rhs.value; }
+  constexpr bool operator!=(ExprId rhs) const { return value != rhs.value; }
+
+private:
+  unsigned value;
+};
+static_assert(std::is_trivially_copyable_v<ExprId> &&
+              std::is_trivially_destructible_v<ExprId>);
+
+//===----------------------------------------------------------------------===//
+/// `LatPoint` identifiers. These are allocated by `Merger::addLat`,
+/// and serve as unique identifiers for the corresponding `LatPoint` object.
+class LatPointId final {
+  friend class Merger;
+  explicit constexpr LatPointId(unsigned value) : value(value) {}
+
+public:
+  /// Constructs a new lattice-point identifier with a dedicated
+  /// known-invalid value.
+  explicit constexpr LatPointId() : value(detail::kInvalidId) {}
+  /// Checks whether the lattice-point identifier has the dedicated
+  /// known-invalid value.
+  constexpr bool isValid() const { return value != detail::kInvalidId; }
+  constexpr bool operator==(LatPointId rhs) const { return value == rhs.value; }
+  constexpr bool operator!=(LatPointId rhs) const { return value != rhs.value; }
+
+private:
+  unsigned value;
+};
+static_assert(std::is_trivially_copyable_v<LatPointId> &&
+              std::is_trivially_destructible_v<LatPointId>);
+
+//===----------------------------------------------------------------------===//
+/// `LatSet` identifiers. These are allocated by `Merger::addSet` (and
+/// by other methods calling that one), and serve as unique identifiers
+/// for the corresponding `SmallVector<LatPointId>` object.
+class LatSetId final {
+  friend class Merger;
+  explicit constexpr LatSetId(unsigned value) : value(value) {}
+
+public:
+  /// Constructs a new lattice-set identifier with a dedicated
+  /// known-invalid value.
+  explicit constexpr LatSetId() : value(detail::kInvalidId) {}
+  /// Checks whether the lattice-set identifier has the dedicated
+  /// known-invalid value.
+  constexpr bool isValid() const { return value != detail::kInvalidId; }
+  constexpr bool operator==(LatSetId rhs) const { return value == rhs.value; }
+  constexpr bool operator!=(LatSetId rhs) const { return value != rhs.value; }
+
+private:
+  unsigned value;
+};
+static_assert(std::is_trivially_copyable_v<LatSetId> &&
+              std::is_trivially_destructible_v<LatSetId>);
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGERNEWTYPES_H_
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -143,7 +143,7 @@
 
   //
   void startReduc(ExprId exp, Value val);
-  bool isReduc() const { return redExp != detail::kInvalidId; }
+  bool isReduc() const { return redExp.isValid(); }
   void updateReduc(Value val);
   Value getReduc() const { return redVal; }
   Value endReduc();
@@ -152,7 +152,7 @@
   Value getValidLexInsert() const { return redValidLexInsert; }
 
   void startCustomReduc(ExprId exp);
-  bool isCustomReduc() const { return redCustom != detail::kInvalidId; }
+  bool isCustomReduc() const { return redCustom.isValid(); }
   Value getCustomRedId();
   void endCustomReduc();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -56,8 +56,7 @@
       latticeMerger(numTensors, numLoops, numFilterLoops, maxRank),
       loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u),
       insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(),
-      redExp(detail::kInvalidId), redCustom(detail::kInvalidId),
-      redValidLexInsert() {}
+      redExp(), redCustom(), redValidLexInsert() {}
 
 LogicalResult CodegenEnv::initTensorExp() {
   // Builds the tensor expression for the Linalg operation in SSA form.
@@ -277,7 +276,7 @@ //===----------------------------------------------------------------------===// void CodegenEnv::startReduc(ExprId exp, Value val) { - assert(!isReduc() && exp != detail::kInvalidId); + assert(!isReduc() && exp.isValid()); redExp = exp; updateReduc(val); } @@ -296,7 +295,7 @@ Value val = redVal; redVal = val; latticeMerger.clearExprValue(redExp); - redExp = detail::kInvalidId; + redExp = ExprId(); return val; } @@ -311,7 +310,7 @@ } void CodegenEnv::startCustomReduc(ExprId exp) { - assert(!isCustomReduc() && exp != detail::kInvalidId); + assert(!isCustomReduc() && exp.isValid()); redCustom = exp; } @@ -322,5 +321,5 @@ void CodegenEnv::endCustomReduc() { assert(isCustomReduc()); - redCustom = detail::kInvalidId; + redCustom = ExprId(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1122,7 +1122,7 @@ linalg::GenericOp op = env.op(); Location loc = op.getLoc(); - if (e == ::mlir::sparse_tensor::detail::kInvalidId) + if (!e.isValid()) return Value(); const TensorExp &exp = env.exp(e); const auto kind = exp.kind; @@ -1157,7 +1157,7 @@ /// Hoists loop invariant tensor loads for which indices have been exhausted. static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, LoopId ldx, bool atStart) { - if (exp == ::mlir::sparse_tensor::detail::kInvalidId) + if (!exp.isValid()) return; if (env.exp(exp).kind == TensorExp::Kind::kTensor) { // Inspect tensor indices. diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -104,14 +104,14 @@ switch (kind) { // Leaf. 
case TensorExp::Kind::kTensor: - assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); + assert(x != detail::kInvalidId && !y.isValid() && !v && !o); tensor = x; return; case TensorExp::Kind::kInvariant: - assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o); + assert(x == detail::kInvalidId && !y.isValid() && v && !o); return; case TensorExp::Kind::kLoopVar: - assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); + assert(x != detail::kInvalidId && !y.isValid() && !v && !o); loop = x; return; // Unary operations. @@ -135,8 +135,8 @@ case TensorExp::Kind::kNegI: case TensorExp::Kind::kCIm: case TensorExp::Kind::kCRe: - assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); - children.e0 = x; + assert(x != detail::kInvalidId && !y.isValid() && !v && !o); + children.e0 = ExprId{x}; children.e1 = y; return; case TensorExp::Kind::kTruncF: @@ -150,21 +150,21 @@ case TensorExp::Kind::kCastIdx: case TensorExp::Kind::kTruncI: case TensorExp::Kind::kBitCast: - assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o); - children.e0 = x; + assert(x != detail::kInvalidId && !y.isValid() && v && !o); + children.e0 = ExprId{x}; children.e1 = y; return; case TensorExp::Kind::kBinaryBranch: case TensorExp::Kind::kSelect: - assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o); - children.e0 = x; + assert(x != detail::kInvalidId && !y.isValid() && !v && o); + children.e0 = ExprId{x}; children.e1 = y; return; case TensorExp::Kind::kUnary: // No assertion on y can be made, as the branching paths involve both // a unary (`mapSet`) and binary (`disjSet`) pathway. assert(x != detail::kInvalidId && !v && o); - children.e0 = x; + children.e0 = ExprId{x}; children.e1 = y; return; // Binary operations. 
@@ -187,14 +187,14 @@ case TensorExp::Kind::kShrS: case TensorExp::Kind::kShrU: case TensorExp::Kind::kShlI: - assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o); - children.e0 = x; + assert(x != detail::kInvalidId && y.isValid() && !v && !o); + children.e0 = ExprId{x}; children.e1 = y; return; case TensorExp::Kind::kBinary: case TensorExp::Kind::kReduce: - assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o); - children.e0 = x; + assert(x != detail::kInvalidId && y.isValid() && !v && o); + children.e0 = ExprId{x}; children.e1 = y; return; } @@ -226,37 +226,37 @@ ExprId Merger::addTensorExp(TensorId t) { assert(isValidTensorId(t)); const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId, - Value(), nullptr); + tensorExps.emplace_back(TensorExp::Kind::kTensor, t, ExprId(), Value(), + nullptr); return eNew; } ExprId Merger::addLoopVarExp(LoopId i) { assert(isValidLoopId(i)); const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId, - Value(), nullptr); + tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, ExprId(), Value(), + nullptr); return eNew; } ExprId Merger::addInvariantExp(Value v) { const ExprId eNew(tensorExps.size()); tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId, - detail::kInvalidId, v, nullptr); + ExprId(), v, nullptr); return eNew; } ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op) { assert(k > TensorExp::Kind::kLoopVar); const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(k, e0, e1, Value(), op); + tensorExps.emplace_back(k, e0.value, e1, Value(), op); return eNew; } ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op) { assert(k > TensorExp::Kind::kLoopVar); const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(k, e, detail::kInvalidId, v, op); + tensorExps.emplace_back(k, e.value, ExprId(), v, op); 
return eNew; } @@ -265,7 +265,7 @@ const unsigned size = numLoops * numTensors; const TensorLoopId b = makeTensorLoopId(t, i); latPoints.emplace_back(size, e); - latPoints[pNew].bits.set(b); + latPoints[pNew.value].bits.set(b); return pNew; } @@ -297,7 +297,7 @@ LatSetId Merger::conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1, Operation *op) { const LatSetId sNew = addSet(); - auto &setNew = latSets[sNew]; + auto &setNew = latSets[sNew.value]; for (const LatPointId p0 : set(s0)) for (const LatPointId p1 : set(s1)) setNew.push_back(conjLat(kind, p0, p1, op)); @@ -308,7 +308,7 @@ Operation *op) { const LatSetId sNew = conjSet(kind, s0, s1, op); // Followed by all in s0. - latSets[sNew].append(latSets[s0]); + latSets[sNew.value].append(latSets[s0.value]); // Map binary 0-y to unary -y. // TODO: move this if-else logic into buildLattices if (kind == TensorExp::Kind::kSubF) @@ -318,7 +318,7 @@ else if (kind == TensorExp::Kind::kSubI) s1 = mapSet(TensorExp::Kind::kNegI, s1); // Followed by all in s1. - latSets[sNew].append(latSets[s1]); + latSets[sNew.value].append(latSets[s1.value]); return sNew; } @@ -332,13 +332,13 @@ if (includeLeft) { if (opleft) s0 = mapSet(ltrans, s0, Value(), opleft); - latSets[sNew].append(latSets[s0]); + latSets[sNew.value].append(latSets[s0.value]); } // Right Region. 
if (includeRight) { if (opright) s1 = mapSet(rtrans, s1, Value(), opright); - latSets[sNew].append(latSets[s1]); + latSets[sNew.value].append(latSets[s1.value]); } return sNew; } @@ -347,9 +347,9 @@ Operation *op) { assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect); const LatSetId sNew = addSet(); - auto &setNew = latSets[sNew]; + auto &setNew = latSets[sNew.value]; for (const LatPointId p : set(s0)) { - const auto &point = latPoints[p]; + const auto &point = latPoints[p.value]; setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op))); } return sNew; @@ -357,7 +357,7 @@ LatSetId Merger::optimizeSet(LatSetId s0) { const LatSetId sNew = addSet(); - auto &setNew = latSets[sNew]; + auto &setNew = latSets[sNew.value]; const auto &set0 = set(s0); assert(!set0.empty()); const LatPointId p0 = set0[0]; @@ -365,7 +365,7 @@ bool add = true; if (p0 != p1) { // Check whether this is a straightforward copy. - if (expIsTensor(latPoints[p1].exp, outTensor)) + if (expIsTensor(latPoints[p1.value].exp, outTensor)) continue; // Check whether this conjunction is already covered. for (const LatPointId p2 : setNew) { @@ -381,7 +381,7 @@ setNew.push_back(p1); } for (const LatPointId p : setNew) - latPoints[p].simple = simplifyCond(sNew, p); + latPoints[p.value].simple = simplifyCond(sNew, p); return sNew; } @@ -861,7 +861,7 @@ if (hasSparseOut && t == outTensor) t = syntheticTensor; } - latSets[s].push_back(addLat(t, i, e)); + latSets[s.value].push_back(addLat(t, i, e)); return s; } // Unary operations.