diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h @@ -220,8 +220,10 @@ void setElideExpr(bool b) { elideExpr = b; } constexpr SparseTensorDimSliceAttr getSlice() const { return slice; } - /// Checks whether the variables bound/used by this spec are valid - /// with respect to the given ranks. + /// Checks whether the variables bound/used by this spec are valid with + /// respect to the given ranks. Note that null `DimExpr` is considered + /// to be vacuously valid, and therefore calling `setExpr` invalidates + /// the result of this predicate. bool isValid(Ranks const &ranks) const; // TODO(wrengr): Use it or loose it. diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp @@ -150,8 +150,6 @@ rhs.printStrong(os); } else { // Combination of all the special rules for addition/subtraction. - // TODO(wrengr): despite being succinct, this is prolly too confusing for - // readers. lhs.printWeak(os); const auto rx = matchNeg(rhs); os << (rx ? " - " : " + "); @@ -180,8 +178,9 @@ : var(var), expr(expr), slice(slice) {} bool DimSpec::isValid(Ranks const &ranks) const { - return ranks.isValid(var) && ranks.isValid(expr); - // TODO(wrengr): is there anything in `slice` that needs validation? + // Nothing in `slice` needs additional validation. + // We explicitly consider null-expr to be vacuously valid. + return ranks.isValid(var) && (!expr || ranks.isValid(expr)); } bool DimSpec::isFunctionOf(VarSet const &vars) const { @@ -220,8 +219,8 @@ } bool LvlSpec::isValid(Ranks const &ranks) const { + // Nothing in `type` needs additional validation. return ranks.isValid(var) && ranks.isValid(expr); - // TODO(wrengr): is there anything in `type` that needs validation? } bool LvlSpec::isFunctionOf(VarSet const &vars) const { diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h @@ -179,7 +179,10 @@ public: constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {} Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {} - Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) {} + // TODO(wrengr): Should make the first argument an `ExprKind` instead...? + Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) { + assert(vk != VarKind::Symbol); + } constexpr bool operator==(Var other) const { return impl == other.impl; } constexpr bool operator!=(Var other) const { return !(*this == other); } @@ -345,13 +348,8 @@ enum class ID : unsigned {}; private: - // FUTURE_CL(wrengr): We could use the high-bit of `Var::Impl` to - // store the `std::optional` bit, therefore allowing us to bitbash the - // `num` and `kind` fields together. - // StringRef name; // The bare-id used in the MLIR source. llvm::SMLoc loc; // The location of the first occurence. - // TODO(wrengr): See the above `LocatedVar` note. ID id; // The unique `VarInfo`-identifier. std::optional num; // The unique `Var`-identifier (if resolved). VarKind kind; // The kind of variable. diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp @@ -33,10 +33,9 @@ //===----------------------------------------------------------------------===// bool Ranks::isValid(DimLvlExpr expr) const { - // FIXME(wrengr): we have cases without affine expr at an early point - if (!expr.getAffineExpr()) - return true; - // Each `DimLvlExpr` only allows one kind of non-symbol variable. + assert(expr); + // Compute the maximum identifiers for symbol-vars and dim/lvl-vars + // (each `DimLvlExpr` only allows one kind of non-symbol variable). int64_t maxSym = -1, maxVar = -1; // TODO(wrengr): If we run into ASan issues, that may be due to the // "`{{...}}`" syntax; so we may want to try using local-variables instead.