diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h @@ -94,56 +94,105 @@ //===----------------------------------------------------------------------===// /// A concrete variable, to be used in our variant of `AffineExpr`. class Var { + // Design Note: This class makes several distinctions which may at first + // seem unnecessary but are in fact needed for implementation reasons. + // These distinctions are summarized as follows: + // + // * `Var` + // Client-facing class for `VarKind` + `Var::Num` pairs, with RTTI + // support for subclasses with a fixed `VarKind`. + // * `Var::Num` + // Client-facing typedef for the type of variable numbers; defined + // so that client code can use it to disambiguate/document when things + // are intended to be variable numbers, as opposed to some other thing + // which happens to be represented as `unsigned`. + // * `Var::Storage` + // Private typedef for the storage of `Var::Impl`; defined only because + // it's also needed for defining `kMaxNum`. Note that this type must be + // kept distinct from `Var::Num`: not only can they be different C++ types + // (even though they currently happen to be the same), but they + // also use different bitwise representations. + // * `Var::Impl` + // The underlying implementation of `Var`; needed by RTTI to serve as + // an intermediary between `Var` and `Var::Storage`. That is, we want + // the RTTI methods to select the `U(Var::Impl)` ctor, without any + // possibility of confusing that with the `U(Var::Num)` ctor or with + // the copy-ctor. (Although the `U(Var::Impl)` ctor is effectively + // identical to the copy-ctor, it doesn't have the type that C++ expects + // for a copy-ctor.)
+ // + // TODO: See if it'd be cleaner to use "llvm/ADT/Bitfields.h" in lieu + // of doing our own bitbashing (though that seems to only be used by LLVM + // for defining machine/assembly ops, and not anywhere else in LLVM/MLIR). public: - /// Typedef to help disambiguate different uses of `unsigned`. + /// Typedef for the type of variable numbers. using Num = unsigned; private: - /// The underlying storage representation of `Var`. Note that this type - /// should be kept distinct from `Num`. Not only can they be different - /// C++ types (even though they currently happen to be the same), but - /// they also use different bitwise representations. - // - // FUTURE_CL(wrengr): Rather than rolling our own, we should - // consider using "llvm/ADT/Bitfields.h"; though that seems to only - // be used by LLVM for the sake of defining machine/assembly ops. - // Or we could consider abusing `PointerIntPair`... - using Impl = unsigned; - Impl impl; - - /// The largest `Var::Num` supported by `Var::Impl`. Two low-order - /// bits are reserved for storing the `VarKind`, and one high-order bit - /// is reserved for future use (e.g., to support `DenseMapInfo` while - /// maintaining the usual numeric values for "empty" and "tombstone"). + /// Typedef for the underlying storage of `Var::Impl`. + using Storage = unsigned; + + /// The largest `Var::Num` supported by `Var`/`Var::Impl`/`Var::Storage`. + /// Two low-order bits are reserved for storing the `VarKind`, + /// and one high-order bit is reserved for future use (e.g., to support + /// `DenseMapInfo` while maintaining the usual numeric values for + /// "empty" and "tombstone"). static constexpr Num kMaxNum = - static_cast(std::numeric_limits::max() >> 3); + static_cast(std::numeric_limits::max() >> 3); public: + /// Checks whether the number would be accepted by `Var(VarKind,Var::Num)`. + // // This must be public for `VarInfo` to use it (whereas we don't want // to expose the `impl` field via friendship).
static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; } - constexpr Var(VarKind vk, Num n) - : impl((static_cast(n) << 2) | - static_cast(to_underlying(vk))) { - assert(isWF(vk) && "unknown VarKind"); - assert(isWF_Num(n) && "Var::Num is too large"); - } +protected: + /// The underlying implementation of `Var`. Note that this must be kept + /// distinct from `Var` itself, since we want to ensure that the RTTI + /// methods will select the `U(Var::Impl)` ctor rather than selecting + /// the `U(Var::Num)` ctor. + class Impl final { + Storage data; + + public: + constexpr Impl(VarKind vk, Num n) + : data((static_cast(n) << 2) | + static_cast(to_underlying(vk))) { + assert(isWF(vk) && "unknown VarKind"); + assert(isWF_Num(n) && "Var::Num is too large"); + } + constexpr bool operator==(Impl other) const { return data == other.data; } + constexpr bool operator!=(Impl other) const { return !(*this == other); } + constexpr VarKind getKind() const { return static_cast(data & 3); } + constexpr Num getNum() const { return static_cast(data >> 2); } + }; + static_assert(IsZeroCostAbstraction); + +private: + Impl impl; + +protected: + /// Protected ctor for the RTTI methods (`cast`/`dyn_cast`) to use.
+ constexpr explicit Var(Impl impl) : impl(impl) {} + +public: + constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {} Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {} Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) {} constexpr bool operator==(Var other) const { return impl == other.impl; } constexpr bool operator!=(Var other) const { return !(*this == other); } - constexpr VarKind getKind() const { return static_cast(impl & 3); } - constexpr Num getNum() const { return static_cast(impl >> 2); } + constexpr VarKind getKind() const { return impl.getKind(); } + constexpr Num getNum() const { return impl.getNum(); } template constexpr bool isa() const; template constexpr U cast() const; template - constexpr U dyn_cast() const; + constexpr std::optional dyn_cast() const; void print(llvm::raw_ostream &os) const; void print(AsmPrinter &printer) const; @@ -152,6 +201,7 @@ static_assert(IsZeroCostAbstraction); class SymVar final : public Var { + using Var::Var; // inherit `Var(Impl)` ctor for RTTI use. NOTE(review): `using Var::Var` also inherits the public `Var(VarKind,Num)` and `Var(VarKind,AffineDimExpr)` ctors (unless this class redeclares those signatures), so a `SymVar` whose kind isn't `Symbol` can be constructed; confirm intended. public: static constexpr VarKind Kind = VarKind::Symbol; static constexpr bool classof(Var const *var) { @@ -163,6 +213,7 @@ static_assert(IsZeroCostAbstraction); class DimVar final : public Var { + using Var::Var; // inherit `Var(Impl)` ctor for RTTI use. NOTE(review): `using Var::Var` also inherits the public `Var(VarKind,Num)` and implicit `Var(AffineSymbolExpr)` ctors (unless this class redeclares those signatures), so a `DimVar` whose kind isn't `Dimension` can be constructed; confirm intended. public: static constexpr VarKind Kind = VarKind::Dimension; static constexpr bool classof(Var const *var) { @@ -174,6 +225,7 @@ static_assert(IsZeroCostAbstraction); class LvlVar final : public Var { + using Var::Var; // inherit `Var(Impl)` ctor for RTTI use. NOTE(review): `using Var::Var` also inherits the public `Var(VarKind,Num)` and implicit `Var(AffineSymbolExpr)` ctors (unless this class redeclares those signatures), so an `LvlVar` whose kind isn't `Level` can be constructed; confirm intended. public: static constexpr VarKind Kind = VarKind::Level; static constexpr bool classof(Var const *var) { @@ -202,12 +254,14 @@ template constexpr U Var::cast() const { assert(isa()); - return U(impl >> 2); // NOTE TO Wren: confirm this fix + // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)` + return U(impl); } template -constexpr U Var::dyn_cast() const { - return isa() ?
U(impl >> 2) : U(); +constexpr std::optional Var::dyn_cast() const { + // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)` + return isa() ? std::make_optional(U(impl)) : std::nullopt; } //===----------------------------------------------------------------------===//