diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h @@ -105,10 +105,14 @@ // TODO(wrengr): Most if not all of these don't actually need to be // methods, they could be free-functions instead. // + Var castAnyVar() const; + std::optional dyn_castAnyVar() const; SymVar castSymVar() const; + std::optional dyn_castSymVar() const; Var castDimLvlVar() const; + std::optional dyn_castDimLvlVar() const; int64_t castConstantValue() const; - std::optional tryGetConstantValue() const; + std::optional dyn_castConstantValue() const; bool hasConstantValue(int64_t val) const; DimLvlExpr getLHS() const; DimLvlExpr getRHS() const; @@ -155,6 +159,12 @@ return expr->getExprKind() == Kind; } constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {} + + LvlVar castLvlVar() const { return castDimLvlVar().cast(); } + std::optional dyn_castLvlVar() const { + const auto var = dyn_castDimLvlVar(); + return var ? std::make_optional(var->cast()) : std::nullopt; + } }; static_assert(IsZeroCostAbstraction); @@ -169,6 +179,12 @@ return expr->getExprKind() == Kind; } constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {} + + DimVar castDimVar() const { return castDimLvlVar().cast(); } + std::optional dyn_castDimVar() const { + const auto var = dyn_castDimLvlVar(); + return var ? std::make_optional(var->cast()) : std::nullopt; + } }; static_assert(IsZeroCostAbstraction); diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp @@ -16,19 +16,46 @@ // `DimLvlExpr` implementation. //===----------------------------------------------------------------------===// +Var DimLvlExpr::castAnyVar() const { + assert(expr && "uninitialized DimLvlExpr"); + const auto var = dyn_castAnyVar(); + assert(var && "expected DimLvlExpr to be a Var"); + return *var; +} + +std::optional DimLvlExpr::dyn_castAnyVar() const { + if (const auto s = expr.dyn_cast_or_null()) + return SymVar(s); + if (const auto x = expr.dyn_cast_or_null()) + return Var(getAllowedVarKind(), x); + return std::nullopt; +} + SymVar DimLvlExpr::castSymVar() const { return SymVar(expr.cast()); } +std::optional DimLvlExpr::dyn_castSymVar() const { + if (const auto s = expr.dyn_cast_or_null()) + return SymVar(s); + return std::nullopt; +} + Var DimLvlExpr::castDimLvlVar() const { return Var(getAllowedVarKind(), expr.cast()); } +std::optional DimLvlExpr::dyn_castDimLvlVar() const { + if (const auto x = expr.dyn_cast_or_null()) + return Var(getAllowedVarKind(), x); + return std::nullopt; +} + int64_t DimLvlExpr::castConstantValue() const { return expr.cast().getValue(); } -std::optional DimLvlExpr::tryGetConstantValue() const { +std::optional DimLvlExpr::dyn_castConstantValue() const { const auto k = expr.dyn_cast_or_null(); return k ? std::make_optional(k.getValue()) : std::nullopt; } @@ -98,7 +125,7 @@ return MatchNeg{DimLvlExpr{expr.getExprKind(), AffineExpr()}, val}; } if (op == AffineExprKind::Mul) - if (const auto rval = rhs.tryGetConstantValue(); rval && *rval < 0) + if (const auto rval = rhs.dyn_castConstantValue(); rval && *rval < 0) return MatchNeg{lhs, *rval}; return std::nullopt; }