diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -52,6 +52,10 @@ add_mlir_dialect_library(MLIRSparseTensorDialect
   SparseTensorDialect.cpp
+  Detail/Var.cpp
+  Detail/DimLvlMap.cpp
+  Detail/LvlTypeParser.cpp
+  Detail/DimLvlMapParser.cpp
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -0,0 +1,317 @@
+//===- DimLvlMap.h ----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// FIXME(wrengr): The `DimLvlMap` class must be public so that it can
+// be named as the storage representation of the parameter for the tblgen
+// defn of STEA. We may well need to make the other classes public too,
+// so that the rest of the compiler can use them when necessary.
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
+#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
+
+#include "Var.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+
+namespace mlir {
+namespace sparse_tensor {
+namespace ir_detail {
+
+//===----------------------------------------------------------------------===//
+// TODO(wrengr): Give this enum a better name, so that it fits together
+// with the name of the `DimLvlExpr` class (which may also want a better
+// name). Perhaps make this a nested-type too.
+//
+// NOTE: In the future we will extend this enum to include "counting
+// expressions" required for supporting ITPACK/ELL. Therefore the current
+// underlying-type and representation values should not be relied upon.
+enum class ExprKind : bool { Dimension = false, Level = true };
+
+/// Returns the kind of variable that may occur inside an expression of
+/// kind `ek`: dimension-expressions mention level-variables, and
+/// level-expressions mention dimension-variables (see the `static_assert`
+/// immediately below).
+// TODO(wrengr): still needs a better name....
+constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
+  using VK = std::underlying_type_t<VarKind>;
+  return VarKind{2 * static_cast<VK>(!to_underlying(ek))};
+}
+static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
+              getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
+
+//===----------------------------------------------------------------------===//
+// TODO(wrengr): The goal of this class is to capture a proof that
+// we've verified that the given `AffineExpr` only has variables of the
+// appropriate kind(s). So we need to actually prove/verify that in the
+// ctor or all its callsites!
+class DimLvlExpr {
+private:
+  // FIXME(wrengr): Per the coding guidelines (the original link was lost),
+  // the `kind` field should be private and const. However, beware
+  // that if we mark any field as `const` or if the fields have differing
+  // `private`/`protected` privileges then the `IsZeroCostAbstraction`
+  // assertion will fail!
+  // (Also, iirc, if we end up moving the `expr` to the subclasses
+  // instead, that'll also cause `IsZeroCostAbstraction` to fail.)
+  ExprKind kind;
+  AffineExpr expr;
+
+public:
+  constexpr DimLvlExpr(ExprKind ek, AffineExpr expr) : kind(ek), expr(expr) {}
+
+  //
+  // Boolean operators.
+  //
+  constexpr bool operator==(DimLvlExpr other) const {
+    return kind == other.kind && expr == other.expr;
+  }
+  constexpr bool operator!=(DimLvlExpr other) const {
+    return !(*this == other);
+  }
+  /// True iff the wrapped `AffineExpr` is non-null.
+  constexpr explicit operator bool() const { return static_cast<bool>(expr); }
+
+  //
+  // RTTI support (for the `DimLvlExpr` class itself).
+  // The only supported targets `U` are `DimExpr` and `LvlExpr`.
+  //
+  template <typename U>
+  constexpr bool isa() const;
+  template <typename U>
+  constexpr U cast() const;
+  template <typename U>
+  constexpr U dyn_cast() const;
+
+  //
+  // Simple getters.
+  //
+  constexpr ExprKind getExprKind() const { return kind; }
+  constexpr VarKind getAllowedVarKind() const {
+    return getVarKindAllowedInExpr(kind);
+  }
+  constexpr AffineExpr getAffineExpr() const { return expr; }
+  /// Requires a non-null expression (asserted).
+  AffineExprKind getAffineKind() const {
+    assert(expr);
+    return expr.getKind();
+  }
+  /// Returns null for the null expression.
+  MLIRContext *getContext() const { return expr ? expr.getContext() : nullptr; }
+
+  //
+  // Getters for handling `AffineExpr` subclasses.
+  //
+  // TODO(wrengr): is there any way to make these typesafe without too much
+  // templating?
+  // TODO(wrengr): Most if not all of these don't actually need to be
+  // methods, they could be free-functions instead.
+  //
+  SymVar castSymVar() const;
+  Var castDimLvlVar() const;
+  int64_t castConstantValue() const;
+  std::optional<int64_t> tryGetConstantValue() const;
+  bool hasConstantValue(int64_t val) const;
+  DimLvlExpr getLHS() const;
+  DimLvlExpr getRHS() const;
+  std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;
+
+  /// Checks whether the variables bound/used by this spec are valid
+  /// with respect to the given ranks.
+  bool isValid(Ranks const &ranks) const;
+
+  void print(llvm::raw_ostream &os) const;
+  void print(AsmPrinter &printer) const;
+  void dump() const;
+
+protected:
+  // Variant of `mlir::AsmPrinter::Impl::BindingStrength`
+  enum class BindingStrength : bool { Weak = false, Strong = true };
+
+  // TODO(wrengr): Does our version of `printAffineExprInternal` really
+  // need to be a method, or could it be a free-function instead? (assuming
+  // `BindingStrength` goes with it).
+  void printAffineExprInternal(llvm::raw_ostream &os,
+                               BindingStrength enclosingTightness) const;
+  void printStrong(llvm::raw_ostream &os) const {
+    printAffineExprInternal(os, BindingStrength::Strong);
+  }
+  void printWeak(llvm::raw_ostream &os) const {
+    printAffineExprInternal(os, BindingStrength::Weak);
+  }
+};
+static_assert(IsZeroCostAbstraction<DimLvlExpr>);
+
+// FUTURE_CL(wrengr): It would be nice to have the subclasses override
+// `getRHS`, `getLHS`, `unpackBinop`, and `castDimLvlVar` to give them
+// the proper covariant return types.
+//
+/// A `DimLvlExpr` whose kind is statically known to be `ExprKind::Dimension`.
+class DimExpr final : public DimLvlExpr {
+  // FIXME(wrengr): These two are needed for the current RTTI implementation.
+  friend class DimLvlExpr;
+  constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
+
+public:
+  static constexpr ExprKind Kind = ExprKind::Dimension;
+  static constexpr bool classof(DimLvlExpr const *expr) {
+    return expr->getExprKind() == Kind;
+  }
+  constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
+};
+static_assert(IsZeroCostAbstraction<DimExpr>);
+
+/// A `DimLvlExpr` whose kind is statically known to be `ExprKind::Level`.
+class LvlExpr final : public DimLvlExpr {
+  // FIXME(wrengr): These two are needed for the current RTTI implementation.
+  friend class DimLvlExpr;
+  constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
+
+public:
+  static constexpr ExprKind Kind = ExprKind::Level;
+  static constexpr bool classof(DimLvlExpr const *expr) {
+    return expr->getExprKind() == Kind;
+  }
+  constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
+};
+static_assert(IsZeroCostAbstraction<LvlExpr>);
+
+// FIXME(wrengr): See comments elsewhere re RTTI implementation issues/questions
+
+/// Returns true iff this expression's kind is `U::Kind`. Any `U` other than
+/// `DimExpr` or `LvlExpr` is rejected at compile time, so this constexpr
+/// function can never fall off its end without returning.
+template <typename U>
+constexpr bool DimLvlExpr::isa() const {
+  static_assert(std::is_same_v<U, DimExpr> || std::is_same_v<U, LvlExpr>,
+                "DimLvlExpr::isa<U>() requires U to be DimExpr or LvlExpr");
+  return getExprKind() == U::Kind;
+}
+
+/// Views this expression as a `U`; the kind must match (asserted).
+template <typename U>
+constexpr U DimLvlExpr::cast() const {
+  assert(isa<U>());
+  return U(*this);
+}
+
+/// Views this expression as a `U` if the kind matches; otherwise returns a
+/// `U` wrapping the null `AffineExpr` (which converts to `false`).
+template <typename U>
+constexpr U DimLvlExpr::dyn_cast() const {
+  // `U` has no default constructor, so the failure case must explicitly
+  // wrap a null `AffineExpr`.
+  return isa<U>() ? U(*this) : U(AffineExpr());
+}
+
+//===----------------------------------------------------------------------===//
+/// The full `dimVar = dimExpr : dimSlice` specification for a given dimension.
+class DimSpec final {
+  /// The dimension-variable bound by this specification.
+  DimVar var;
+  /// The dimension-expression. The `DimSpec` ctor treats this field
+  /// as optional; whereas the `DimLvlMap` ctor will fill in (or verify)
+  /// the expression via function-inversion inference.
+  DimExpr expr;
+  /// Can the `expr` be elided when printing? The `DimSpec` ctor assumes
+  /// not (though if `expr` is null it will elide printing that); whereas
+  /// the `DimLvlMap` ctor will reset it as appropriate.
+  bool elideExpr = false;
+  /// The dimension-slice; optional, default is null.
+  SparseTensorDimSliceAttr slice;
+
+public:
+  DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice);
+
+  constexpr DimVar getBoundVar() const { return var; }
+  constexpr bool hasExpr() const { return static_cast<bool>(expr); }
+  constexpr DimExpr getExpr() const { return expr; }
+  /// Sets the expression; only allowed while no expression is set (asserted).
+  void setExpr(DimExpr newExpr) {
+    assert(!hasExpr());
+    expr = newExpr;
+  }
+  constexpr bool canElideExpr() const { return elideExpr; }
+  void setElideExpr(bool b) { elideExpr = b; }
+  constexpr SparseTensorDimSliceAttr getSlice() const { return slice; }
+
+  /// Checks whether the variables bound/used by this spec are valid
+  /// with respect to the given ranks.
+  bool isValid(Ranks const &ranks) const;
+
+  // TODO(wrengr): Use it or lose it.
+  bool isFunctionOf(Var var) const;
+  bool isFunctionOf(VarSet const &vars) const;
+  void getFreeVars(VarSet &vars) const;
+
+  void print(llvm::raw_ostream &os, bool wantElision = true) const;
+  void print(AsmPrinter &printer, bool wantElision = true) const;
+  void dump() const;
+};
+// Although this class is more than just a newtype/wrapper, we do want
+// to ensure that storing them into `SmallVector` is efficient.
+static_assert(IsZeroCostAbstraction<DimSpec>);
+
+//===----------------------------------------------------------------------===//
+/// The full `lvlVar = lvlExpr : lvlType` specification for a given level.
+class LvlSpec final {
+  /// The level-variable bound by this specification.
+  LvlVar var;
+  /// Can the `var` be elided when printing? The `LvlSpec` ctor assumes not;
+  /// whereas the `DimLvlMap` ctor will reset this as appropriate.
+  bool elideVar = false;
+  /// The level-expression.
+  //
+  // NOTE: For now we use `LvlExpr` because all level-expressions must be
+  // `AffineExpr`; however, in the future we will also want to allow "counting
+  // expressions", and potentially other kinds of non-affine level-expressions.
+  // Which kinds of `DimLvlExpr` are allowed will depend on the `DimLevelType`,
+  // so we may consider defining another class for pairing those two together
+  // to ensure that the pair is well-formed.
+  LvlExpr expr;
+  /// The level-type (== level-format + lvl-properties).
+  DimLevelType type;
+
+public:
+  LvlSpec(LvlVar var, LvlExpr expr, DimLevelType type);
+
+  constexpr LvlVar getBoundVar() const { return var; }
+  constexpr bool canElideVar() const { return elideVar; }
+  void setElideVar(bool b) { elideVar = b; }
+  constexpr LvlExpr getExpr() const { return expr; }
+  constexpr DimLevelType getType() const { return type; }
+
+  /// Checks whether the variables bound/used by this spec are valid
+  /// with respect to the given ranks.
+  //
+  // NOTE: Once we introduce "counting expressions" this will need
+  // a more sophisticated implementation than `DimSpec::isValid` does.
+  bool isValid(Ranks const &ranks) const;
+
+  // TODO(wrengr): Use it or lose it.
+  bool isFunctionOf(Var var) const;
+  bool isFunctionOf(VarSet const &vars) const;
+  void getFreeVars(VarSet &vars) const;
+
+  void print(llvm::raw_ostream &os, bool wantElision = true) const;
+  void print(AsmPrinter &printer, bool wantElision = true) const;
+  void dump() const;
+};
+// Although this class is more than just a newtype/wrapper, we do want
+// to ensure that storing them into `SmallVector` is efficient.
+static_assert(IsZeroCostAbstraction<LvlSpec>);
+
+//===----------------------------------------------------------------------===//
+/// A complete dimension-to-level map: the number of symbols, one `DimSpec`
+/// per dimension, and one `LvlSpec` per level.
+class DimLvlMap final {
+  unsigned symRank;
+  SmallVector<DimSpec> dimSpecs;
+  SmallVector<LvlSpec> lvlSpecs;
+
+  // Checks for integrity of variable-binding structure.
+  // This is already called by the ctor.
+  bool isWF() const;
+
+public:
+  DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
+            ArrayRef<LvlSpec> lvlSpecs);
+
+  unsigned getSymRank() const { return symRank; }
+  unsigned getDimRank() const { return dimSpecs.size(); }
+  unsigned getLvlRank() const { return lvlSpecs.size(); }
+  unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
+  Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }
+
+  /// Returns the specification of the `d`-th dimension.
+  const DimSpec &getDimSpec(unsigned d) const { return dimSpecs[d]; }
+  /// Returns the specification of the `l`-th level.
+  const LvlSpec &getLvlSpec(unsigned l) const { return lvlSpecs[l]; }
+  /// Returns the level-type of the `i`-th level. This does not modify the
+  /// map, so it is usable on `const DimLvlMap`.
+  DimLevelType getDimLevelType(unsigned i) const {
+    return lvlSpecs[i].getType();
+  }
+
+  void print(llvm::raw_ostream &os, bool wantElision = true) const;
+  void print(AsmPrinter &printer, bool wantElision = true) const;
+  void dump() const;
+};
+
+//===----------------------------------------------------------------------===//
+
+} // namespace ir_detail
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -0,0 +1,322 @@
+//===- DimLvlMap.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DimLvlMap.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+using namespace mlir::sparse_tensor::ir_detail;
+
+//===----------------------------------------------------------------------===//
+// `DimLvlExpr` implementation.
+//===----------------------------------------------------------------------===//
+
+/// Requires the expression to be an `AffineSymbolExpr`.
+SymVar DimLvlExpr::castSymVar() const {
+  return SymVar(expr.cast<AffineSymbolExpr>());
+}
+
+/// Requires the expression to be an `AffineDimExpr`; the resulting `Var`
+/// has the variable-kind allowed in this expression's `ExprKind`.
+Var DimLvlExpr::castDimLvlVar() const {
+  return Var(getAllowedVarKind(), expr.cast<AffineDimExpr>());
+}
+
+/// Requires the expression to be an `AffineConstantExpr`.
+int64_t DimLvlExpr::castConstantValue() const {
+  return expr.cast<AffineConstantExpr>().getValue();
+}
+
+/// Returns the constant value, or nullopt if the expression is null or
+/// not a constant.
+std::optional<int64_t> DimLvlExpr::tryGetConstantValue() const {
+  const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
+  return k ? std::make_optional(k.getValue()) : std::nullopt;
+}
+
+// This helper method is akin to `AffineExpr::operator==(int64_t)`
+// except it uses a different implementation, namely the implementation
+// used within `AsmPrinter::Impl::printAffineExprInternal`.
+//
+// wrengr guesses that `AsmPrinter::Impl::printAffineExprInternal` uses
+// this implementation because it avoids constructing the intermediate
+// `AffineConstantExpr(val)` and thus should in theory be a bit faster.
+// However, if it is indeed faster, then the `AffineExpr::operator==`
+// method should be updated to do this instead. And if it isn't any
+// faster, then we should be using `AffineExpr::operator==` instead.
+bool DimLvlExpr::hasConstantValue(int64_t val) const {
+  const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
+  return k && k.getValue() == val;
+}
+
+/// Returns the LHS of a binary expression (same `ExprKind`), or a null
+/// expression if this is not a binary expression (or is itself null).
+DimLvlExpr DimLvlExpr::getLHS() const {
+  const auto binop = expr.dyn_cast_or_null<AffineBinaryOpExpr>();
+  return DimLvlExpr(kind, binop ? binop.getLHS() : nullptr);
+}
+
+/// Returns the RHS of a binary expression (same `ExprKind`), or a null
+/// expression if this is not a binary expression (or is itself null).
+DimLvlExpr DimLvlExpr::getRHS() const {
+  const auto binop = expr.dyn_cast_or_null<AffineBinaryOpExpr>();
+  return DimLvlExpr(kind, binop ? binop.getRHS() : nullptr);
+}
+
+/// Returns `{lhs, kind, rhs}`; `lhs`/`rhs` are null unless this is a binary
+/// expression. Requires a non-null expression (via `getAffineKind`).
+std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr>
+DimLvlExpr::unpackBinop() const {
+  const auto ak = getAffineKind();
+  const auto binop = expr.dyn_cast<AffineBinaryOpExpr>();
+  const DimLvlExpr lhs(kind, binop ? binop.getLHS() : nullptr);
+  const DimLvlExpr rhs(kind, binop ? binop.getRHS() : nullptr);
+  return {lhs, ak, rhs};
+}
+
+void DimLvlExpr::dump() const {
+  print(llvm::errs());
+  llvm::errs() << "\n";
+}
+void DimLvlExpr::print(AsmPrinter &printer) const {
+  print(printer.getStream());
+}
+/// Prints the null expression as "<<NULL>>".
+void DimLvlExpr::print(llvm::raw_ostream &os) const {
+  if (!expr)
+    os << "<<NULL>>";
+  else
+    printWeak(os);
+}
+
+namespace {
+/// Result of `matchNeg`: the (possibly null) LHS and the negative constant.
+struct MatchNeg final : public std::pair<DimLvlExpr, int64_t> {
+  using Base = std::pair<DimLvlExpr, int64_t>;
+  using Base::Base;
+  constexpr DimLvlExpr getLHS() const { return first; }
+  constexpr int64_t getRHS() const { return second; }
+};
+} // namespace
+
+/// Matches the two "negative" shapes that the printer renders as subtraction:
+/// a negative constant `c` (yields a null LHS and `c`), or `lhs * c` with a
+/// negative constant `c` (yields `lhs` and `c`). Requires a non-null `expr`.
+static std::optional<MatchNeg> matchNeg(DimLvlExpr expr) {
+  const auto [lhs, op, rhs] = expr.unpackBinop();
+  if (op == AffineExprKind::Constant) {
+    const auto val = expr.castConstantValue();
+    if (val < 0)
+      return MatchNeg{DimLvlExpr{expr.getExprKind(), AffineExpr()}, val};
+  }
+  if (op == AffineExprKind::Mul)
+    if (const auto rval = rhs.tryGetConstantValue(); rval && *rval < 0)
+      return MatchNeg{lhs, *rval};
+  return std::nullopt;
+}
+
+// A heavily revised version of `AsmPrinter::Impl::printAffineExprInternal`.
+// Prints this (non-null) expression to `os`. `enclosingTightness` says
+// whether the surrounding context binds tightly: when it is `Strong`, binary
+// expressions are wrapped in parentheses. Leaves (symbols, dim/lvl variables,
+// constants) are printed directly and never parenthesized.
+void DimLvlExpr::printAffineExprInternal(
+    llvm::raw_ostream &os, BindingStrength enclosingTightness) const {
+  // The spelling of the binary operator, used only by the default rule below.
+  const char *binopSpelling = nullptr;
+  switch (getAffineKind()) {
+  case AffineExprKind::SymbolId:
+    os << castSymVar();
+    return;
+  case AffineExprKind::DimId:
+    os << castDimLvlVar();
+    return;
+  case AffineExprKind::Constant:
+    os << castConstantValue();
+    return;
+  case AffineExprKind::Add:
+    binopSpelling = " + "; // N.B., this is unused
+    break;
+  case AffineExprKind::Mul:
+    binopSpelling = " * ";
+    break;
+  case AffineExprKind::FloorDiv:
+    binopSpelling = " floordiv ";
+    break;
+  case AffineExprKind::CeilDiv:
+    binopSpelling = " ceildiv ";
+    break;
+  case AffineExprKind::Mod:
+    binopSpelling = " mod ";
+    break;
+  }
+
+  if (enclosingTightness == BindingStrength::Strong)
+    os << '(';
+
+  const auto [lhs, op, rhs] = unpackBinop();
+  if (op == AffineExprKind::Mul && rhs.hasConstantValue(-1)) {
+    // Pretty print `(lhs * -1)` as "-lhs".
+    os << '-';
+    lhs.printStrong(os);
+  } else if (op != AffineExprKind::Add) {
+    // Default rule for tightly binding binary operators.
+    // (Including `Mul` that didn't match the previous rule.)
+    lhs.printStrong(os);
+    os << binopSpelling;
+    rhs.printStrong(os);
+  } else {
+    // Combination of all the special rules for addition/subtraction.
+    // TODO(wrengr): despite being succinct, this is prolly too confusing for
+    // readers.
+    //
+    // `matchNeg` recognizes an RHS of the form `c` or `x * c` with constant
+    // `c < 0`; such an RHS is printed as subtraction, giving
+    // "lhs - -c" or "lhs - x" (when c == -1) or "lhs - x * -c".
+    // Otherwise the RHS is printed after " + ".
+    lhs.printWeak(os);
+    const auto rx = matchNeg(rhs);
+    os << (rx ? " - " : " + ");
+    // `rlhs` is null exactly when `rhs` was a bare negative constant.
+    const auto &rlhs = rx ? rx->getLHS() : rhs;
+    const auto rrhs = rx ? rx->getRHS() : -1; // value irrelevant when `!rx`
+    const bool nonunit = rrhs != -1; // value irrelevant when `!rx`
+    // Parenthesize `rlhs` when it is followed by " * k" or is itself a sum,
+    // so that e.g. `a - (b + c)` is not printed as `a - b + c`.
+    const bool isStrong =
+        rx && rlhs && (nonunit || rlhs.getAffineKind() == AffineExprKind::Add);
+    if (rlhs)
+      rlhs.printAffineExprInternal(os, BindingStrength{isStrong});
+    if (rx && rlhs && nonunit)
+      os << " * ";
+    // NOTE(review): `matchNeg` accepts any negative constant, so `rrhs` may
+    // be INT64_MIN, in which case `-rrhs` is signed overflow (UB).
+    if (rx && (!rlhs || nonunit))
+      os << -rrhs;
+  }
+
+  if (enclosingTightness == BindingStrength::Strong)
+    os << ')';
+}
+
+//===----------------------------------------------------------------------===//
+// `DimSpec` implementation.
+//===----------------------------------------------------------------------===//
+
+// Unlike the `LvlSpec` ctor, this ctor does not assert that `expr` is
+// non-null: a `DimSpec` may be created without an expression.
+DimSpec::DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice)
+    : var(var), expr(expr), slice(slice) {}
+
+// NOTE(review): `expr` may be null here (see ctor); whether `Ranks::isValid`,
+// `VarSet::occursIn`, and `VarSet::add` accept a null expression is not
+// visible in this file -- confirm against Var.h/Var.cpp.
+bool DimSpec::isValid(Ranks const &ranks) const {
+  return ranks.isValid(var) && ranks.isValid(expr);
+  // TODO(wrengr): is there anything in `slice` that needs validation?
+}
+
+bool DimSpec::isFunctionOf(VarSet const &vars) const {
+  return vars.occursIn(expr);
+}
+
+void DimSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
+
+// Debug helper: prints without elision, followed by a newline, to stderr.
+void DimSpec::dump() const {
+  print(llvm::errs(), /*wantElision=*/false);
+  llvm::errs() << "\n";
+}
+void DimSpec::print(AsmPrinter &printer, bool wantElision) const {
+  print(printer.getStream(), wantElision);
+}
+// Prints `var [= expr] [: slice]`. The expression is omitted when it is
+// null, or when elision is requested and `elideExpr` is set.
+void DimSpec::print(llvm::raw_ostream &os, bool wantElision) const {
+  os << var;
+  if (expr && (!wantElision || !elideExpr))
+    os << " = " << expr;
+  if (slice) {
+    os << " : ";
+    // Call `SparseTensorDimSliceAttr::print` directly, to avoid
+    // printing the mnemonic.
+    slice.print(os);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// `LvlSpec` implementation.
+//===----------------------------------------------------------------------===//
+
+// Both the expression and the level-type are required (asserted), and the
+// level-type must be a valid, defined `DimLevelType`.
+LvlSpec::LvlSpec(LvlVar var, LvlExpr expr, DimLevelType type)
+    : var(var), expr(expr), type(type) {
+  assert(expr);
+  assert(isValidDLT(type) && !isUndefDLT(type));
+}
+
+bool LvlSpec::isValid(Ranks const &ranks) const {
+  return ranks.isValid(var) && ranks.isValid(expr);
+  // TODO(wrengr): is there anything in `type` that needs validation?
+}
+
+bool LvlSpec::isFunctionOf(VarSet const &vars) const {
+  return vars.occursIn(expr);
+}
+
+void LvlSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
+
+// Debug helper: prints without elision, followed by a newline, to stderr.
+void LvlSpec::dump() const {
+  print(llvm::errs(), /*wantElision=*/false);
+  llvm::errs() << "\n";
+}
+void LvlSpec::print(AsmPrinter &printer, bool wantElision) const {
+  print(printer.getStream(), wantElision);
+}
+// Prints `[var = ]expr: "type"`. The binding is omitted when elision is
+// requested and `elideVar` is set.
+// NOTE(review): this prints ": " while `DimSpec::print` prints " : "; keep
+// whichever the `LvlSpec` parser expects before unifying them.
+void LvlSpec::print(llvm::raw_ostream &os, bool wantElision) const {
+  if (!wantElision || !elideVar)
+    os << var << " = ";
+  os << expr;
+  os << ": \"" << toMLIRString(type) << "\"";
+}
+
+//===----------------------------------------------------------------------===//
+// `DimLvlMap` implementation.
+//===----------------------------------------------------------------------===//
+
+DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
+                     ArrayRef<LvlSpec> lvlSpecs)
+    : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs) {
+  // First, check integrity of the variable-binding structure.
+  assert(isWF());
+
+  // TODO: Second, we need to infer/validate the `lvlToDim` mapping.
+  // Along the way we should set every `DimSpec::elideExpr` according
+  // to whether the given expression is inferable or not. Notably, this
+  // needs to happen before the code for setting every `LvlSpec::elideVar`,
+  // since if the LvlVar is only used in elided DimExpr, then the
+  // LvlVar should also be elided.
+
+  // Third, we set every `LvlSpec::elideVar` according to whether that
+  // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
+  //
+  // NOTE: the ctor parameters shadow the members, so the members are named
+  // explicitly via `this->` in both loops.
+  VarSet usedVars(getRanks());
+  for (const auto &dimSpec : this->dimSpecs)
+    // A `DimSpec` without an expression contributes no used variables.
+    if (dimSpec.hasExpr() && !dimSpec.canElideExpr())
+      usedVars.add(dimSpec.getExpr());
+  for (auto &lvlSpec : this->lvlSpecs)
+    lvlSpec.setElideVar(!usedVars.contains(lvlSpec.getBoundVar()));
+}
+
+// Returns true iff the i-th `DimSpec` binds the dimension-variable numbered
+// `i`, the i-th `LvlSpec` binds the level-variable numbered `i`, and every
+// spec is valid with respect to this map's ranks.
+bool DimLvlMap::isWF() const {
+  const auto ranks = getRanks();
+  unsigned dimNum = 0;
+  for (const auto &dimSpec : dimSpecs)
+    if (dimSpec.getBoundVar().getNum() != dimNum++ || !dimSpec.isValid(ranks))
+      return false;
+  assert(dimNum == ranks.getDimRank());
+  unsigned lvlNum = 0;
+  for (const auto &lvlSpec : lvlSpecs)
+    if (lvlSpec.getBoundVar().getNum() != lvlNum++ || !lvlSpec.isValid(ranks))
+      return false;
+  assert(lvlNum == ranks.getLvlRank());
+  return true;
+}
+
+// Debug helper: prints without elision, followed by a newline, to stderr.
+void DimLvlMap::dump() const {
+  print(llvm::errs(), /*wantElision=*/false);
+  llvm::errs() << "\n";
+}
+void DimLvlMap::print(AsmPrinter &printer, bool wantElision) const {
+  print(printer.getStream(), wantElision);
+}
+// Prints `[s0, ..., sN] (dimSpecs) -> (lvlSpecs)`; the symbol list is
+// omitted when there are no symbols.
+void DimLvlMap::print(llvm::raw_ostream &os, bool wantElision) const {
+  // Symbolic identifiers.
+  // NOTE: Unlike `AffineMap` we place the SymVar bindings before the DimVar
+  // bindings, since the SymVars may occur within DimExprs and thus this
+  // ordering helps reduce potential user confusion about the scope of bindings
+  // (since it means SymVars and DimVars both bind-forward in the usual way,
+  // whereas only LvlVars have different binding rules).
+  if (symRank != 0) {
+    os << "[s0";
+    for (unsigned i = 1; i < symRank; ++i)
+      os << ", s" << i;
+    os << ']';
+  }
+
+  // Dimension specifiers.
+  os << '(';
+  llvm::interleaveComma(
+      dimSpecs, os, [&](DimSpec const &spec) { spec.print(os, wantElision); });
+  os << ") -> (";
+  // Level specifiers.
+  llvm::interleaveComma(
+      lvlSpecs, os, [&](LvlSpec const &spec) { spec.print(os, wantElision); });
+  os << ')';
+}
+
+//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
@@ -0,0 +1,63 @@
+//===- DimLvlMapParser.h - `DimLvlMap` parser -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAPPARSER_H
+#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAPPARSER_H
+
+#include "DimLvlMap.h"
+#include "LvlTypeParser.h"
+
+namespace mlir {
+namespace sparse_tensor {
+namespace ir_detail {
+
+//===----------------------------------------------------------------------===//
+// NOTE(wrengr): The idea here was originally based on the
+// "lib/AsmParser/AffineParser.cpp"-static class `AffineParser`.
+// Unfortunately, we can't use that class directly since it's file-local.
+// Even worse, both `mlir::detail::Parser` and `mlir::detail::ParserState`
+// are also file-local classes. I've been attempting to convert things
+// over to using `AsmParser` wherever possible, though it's not clear that
+// that'll work...
+class DimLvlMapParser final {
+public:
+  explicit DimLvlMapParser(AsmParser &parser) : parser(parser) {}
+
+  // Parses the input for a sparse tensor dimension-level map
+  // and returns the map on success.
+  FailureOr<DimLvlMap> parseDimLvlMap();
+
+private:
+  // TODO(wrengr): rather than using `OptionalParseResult` and two
+  // out-parameters, should we define a type to encapsulate all that?
+  OptionalParseResult parseVar(VarKind vk, bool isOptional,
+                               CreationPolicy creationPolicy, VarInfo::ID &id,
+                               bool &didCreate);
+  FailureOr<VarInfo::ID> parseVarUsage(VarKind vk);
+  FailureOr<std::pair<Var, bool>> parseVarBinding(VarKind vk, bool isOptional);
+
+  ParseResult parseOptionalSymbolIdList();
+  ParseResult parseDimSpec();
+  ParseResult parseDimSpecList();
+  ParseResult parseLvlSpec();
+  ParseResult parseLvlSpecList();
+
+  AsmParser &parser;
+  LvlTypeParser lvlTypeParser;
+  VarEnv env;
+  SmallVector<DimSpec> dimSpecs;
+  SmallVector<LvlSpec> lvlSpecs;
+};
+
+//===----------------------------------------------------------------------===//
+
+} // namespace ir_detail
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAPPARSER_H
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -0,0 +1,259 @@
+//===- DimLvlMapParser.cpp - `DimLvlMap` parser implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DimLvlMapParser.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+using namespace mlir::sparse_tensor::ir_detail;
+
+//===----------------------------------------------------------------------===//
+// TODO(wrengr): rephrase these to do the trick for gobbling up any trailing
+// semicolon
+//
+// NOTE: Both macros below expand to a bare braced `if` statement (not a
+// `do { ... } while (0)`), and the callers in this file invoke them without
+// a trailing semicolon. A trailing `;` would only add an empty statement,
+// but an `else` written directly after an invocation would bind to the
+// macro's own `if`.
+//
+// NOTE: There's no way for `FAILURE_IF_FAILED` to simultaneously support
+// both `OptionalParseResult` and `InFlightDiagnostic` return types.
+// We can get the compiler to accept the code if we returned "`{}`",
+// however for `OptionalParseResult` that would become the nullopt result,
+// whereas for `InFlightDiagnostic` it would become a result that can
+// be implicitly converted to success. By using "`failure()`" we ensure
+// that `OptionalParseResult` behaves as intended, however that means the
+// macro cannot be used for `InFlightDiagnostic` since there's no implicit
+// conversion.
+//
+// Usage: `FAILURE_IF_FAILED(expr)` returns `failure()` from the enclosing
+// function when `failed(expr)` is true.
+#define FAILURE_IF_FAILED(STMT)                                                \
+  if (failed(STMT)) {                                                          \
+    return failure();                                                          \
+  }
+
+// Although `ERROR_IF` is phrased to return `InFlightDiagnostic`, that type
+// can be implicitly converted to all four of `LogicalResult`, `FailureOr<T>`,
+// `ParseResult`, and `OptionalParseResult`. (However, beware that the
+// conversion to `OptionalParseResult` doesn't properly delegate to
+// `InFlightDiagnostic::operator ParseResult`.)
+//
+// Usage: `ERROR_IF(cond, msg)` emits `msg` at `loc` and returns the
+// resulting diagnostic from the enclosing function when `cond` is true.
+//
+// NOTE: this macro assumes `AsmParser parser` and `SMLoc loc` are in scope.
+// NOTE_TO_SELF(wrengr): The LOC used to always be `parser.getNameLoc()`
+#define ERROR_IF(COND, MSG)                                                    \
+  if (COND) {                                                                  \
+    return parser.emitError(loc, MSG);                                         \
+  }
+
+//===----------------------------------------------------------------------===//
+// `DimLvlMapParser` implementation for variable parsing.
+//===----------------------------------------------------------------------===//
+
+// Our variation on `AffineParser::{parseBareIdExpr,parseIdentifierDefinition}`
+//
+// Parses a bare identifier for a variable of kind `vk` and resolves it in
+// `env` according to `creationPolicy`. On success, sets `varID` and
+// `didCreate` from the `lookupOrCreate` result and returns success. If no
+// identifier is present, returns the null `OptionalParseResult` when
+// `isOptional` is true and emits "expected bare identifier" otherwise.
+// When `lookupOrCreate` yields nothing, emits a policy-specific error.
+OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
+                                              CreationPolicy creationPolicy,
+                                              VarInfo::ID &varID,
+                                              bool &didCreate) {
+  // Save the current location so that we can have error messages point to
+  // the right place. Note that `Parser::emitWrongTokenError` starts off
+  // with the same location as `AsmParserImpl::getCurrentLocation` returns;
+  // however, `Parser` will then do some various munging with the location
+  // before actually using it, so `AsmParser::emitError` can't quite be used
+  // as a drop-in replacement for `Parser::emitWrongTokenError`.
+  const auto loc = parser.getCurrentLocation();
+
+  // Several things to note.
+  // (1) the `Parser::isCurrentTokenAKeyword` method checks the exact
+  // same conditions as the `AffineParser.cpp`-static free-function
+  // `isIdentifier` which is used by `AffineParser::parseBareIdExpr`.
+  // (2) the `{Parser,AsmParserImpl}::parseOptionalKeyword(StringRef*)`
+  // methods do the same song and dance about using
+  // `isCurrentTokenAKeyword`, `getTokenSpelling`, et `consumeToken` as we
+  // would want to do if we could use the `Parser` class directly. It
+  // doesn't provide the nice error handling we want, but we can work around
+  // that.
+  StringRef name;
+  if (failed(parser.parseOptionalKeyword(&name))) {
+    // If not actually optional, then `emitError`.
+    ERROR_IF(!isOptional, "expected bare identifier")
+    // If is actually optional, then return the null `OptionalParseResult`.
+    return std::nullopt;
+  }
+
+  // I don't know if we need to worry about the possibility of the caller
+  // recovering from error and then reusing the `DimLvlMapParser` for subsequent
+  // `parseVar`, but I'm erring on the side of caution by distinguishing
+  // all three possible creation policies.
+  if (const auto res = env.lookupOrCreate(creationPolicy, name, loc, vk)) {
+    varID = res->first;
+    didCreate = res->second;
+    return success();
+  }
+  // TODO(wrengr): these error messages make sense for our intended usage,
+  // but not in general; but it's unclear how best to factor that part out.
+  switch (creationPolicy) {
+  case CreationPolicy::MustNot:
+    return parser.emitError(loc, "use of undeclared identifier '" + name + "'");
+  case CreationPolicy::May:
+    llvm_unreachable("got nullopt for CreationPolicy::May");
+  case CreationPolicy::Must:
+    return parser.emitError(loc, "redefinition of identifier '" + name + "'");
+  }
+  llvm_unreachable("unknown CreationPolicy");
+}
+
+// Parses a (required) use of a variable of kind `vk`, returning its ID.
+FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk) {
+  // Value-initialized so that they are never read uninitialized, even
+  // though `parseVar` only reads them back on success.
+  VarInfo::ID varID{};
+  bool didCreate = false;
+  // We use the policy `May` because we want to allow parsing free/unbound
+  // variables. If we wanted to distinguish between parsing free-var uses
+  // vs bound-var uses, then the latter should use `MustNot`.
+  const auto res =
+      parseVar(vk, /*isOptional=*/false, CreationPolicy::May, varID, didCreate);
+  if (!res.has_value() || failed(*res))
+    return failure();
+  return varID;
+}
+
+// Parses a binding occurrence of a variable of kind `vk` and binds it in
+// `env`. Returns the bound `Var` paired with `true` if an identifier was
+// parsed; if `isOptional` and no identifier is present, binds a fresh unused
+// variable and pairs it with `false`.
+FailureOr<std::pair<Var, bool>>
+DimLvlMapParser::parseVarBinding(VarKind vk, bool isOptional) {
+  VarInfo::ID id{};
+  bool didCreate = false;
+  const auto res =
+      parseVar(vk, isOptional, CreationPolicy::Must, id, didCreate);
+  if (!res.has_value())
+    return std::make_pair(env.bindUnusedVar(vk), false);
+  FAILURE_IF_FAILED(*res)
+  return std::make_pair(env.bindVar(id), true);
+}
+
+//===----------------------------------------------------------------------===//
+// `DimLvlMapParser` implementation for `DimLvlMap` per se.
+//===----------------------------------------------------------------------===//
+
+/// Parses a complete map of the form
+/// `[symbols]? (dim-specs) -> (lvl-specs)`, and then reports any variables
+/// which were used but never bound.
+FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
+  FAILURE_IF_FAILED(parseOptionalSymbolIdList())
+  FAILURE_IF_FAILED(parseDimSpecList())
+  FAILURE_IF_FAILED(parser.parseArrow())
+  FAILURE_IF_FAILED(parseLvlSpecList())
+  // TODO(wrengr): Try to improve the error messages from
+  // `VarEnv::emitErrorIfAnyUnbound`.
+  InFlightDiagnostic ifd = env.emitErrorIfAnyUnbound(parser);
+  if (failed(ifd))
+    return ifd;
+  return DimLvlMap(env.getRanks().getSymRank(), dimSpecs, lvlSpecs);
+}
+
+using Delimiter = mlir::OpAsmParser::Delimiter;
+
+/// Parses the optional square-bracketed, comma-separated list of symbol
+/// variable bindings.
+ParseResult DimLvlMapParser::parseOptionalSymbolIdList() {
+  const auto parseSymVarBinding = [&]() -> ParseResult {
+    return ParseResult(parseVarBinding(VarKind::Symbol, /*isOptional=*/false));
+  };
+  // If I've correctly unpacked how exactly `Parser::parseCommaSeparatedList`
+  // handles the "optional" delimiters vs the non-optional ones, then
+  // the following call to `AsmParser::parseCommaSeparatedList` should
+  // be equivalent to the whole `AffineParse::parseOptionalSymbolIdList`
+  // method (which uses `Parser` methods to handle the optionality instead).
+  return parser.parseCommaSeparatedList(Delimiter::OptionalSquare,
+                                        parseSymVarBinding, " in symbol list");
+}
+
+//===----------------------------------------------------------------------===//
+// `DimLvlMapParser` implementation for `DimSpec`.
+//===----------------------------------------------------------------------===//
+
+/// Parses the parenthesized, comma-separated list of dimension-specifiers.
+ParseResult DimLvlMapParser::parseDimSpecList() {
+  return parser.parseCommaSeparatedList(
+      Delimiter::Paren, [&]() -> ParseResult { return parseDimSpec(); },
+      " in dimension-specifier list");
+}
+
+/// Parses one dimension-specifier: a dimension-variable binding, optionally
+/// followed by `= expr` and/or `: slice-attribute`; appends it to `dimSpecs`.
+ParseResult DimLvlMapParser::parseDimSpec() {
+  const auto res = parseVarBinding(VarKind::Dimension, /*isOptional=*/false);
+  FAILURE_IF_FAILED(res)
+  const DimVar var = res->first.cast<DimVar>();
+
+  DimExpr expr{AffineExpr()};
+  if (succeeded(parser.parseOptionalEqual())) {
+    // FIXME(wrengr): I don't think there's any way to implement this
+    // without replicating the bulk of `AffineParser::parseAffineExpr`
+    // TODO(wrengr): Also, need to make sure the parser uses
+    // `parseVarUsage(VarKind::Level)` so that every `AffineDimExpr`
+    // necessarily corresponds to a `LvlVar` (never a `DimVar`).
+    //
+    // FIXME: proof of concept, parse trivial level vars (viz d0 = l0).
+    auto use = parseVarUsage(VarKind::Level);
+    FAILURE_IF_FAILED(use)
+    // NOTE(review): the parsed level variable `use` is only checked for
+    // success; the expression is built from this dimension's own number
+    // (`var.getNum()`), so `d_i = l_j` with i != j is silently treated as
+    // `d_i = l_i`.  Level vars are not yet bound at this point, so
+    // `env.toVar(*use)` cannot be used here without further changes.
+    AffineExpr a = getAffineDimExpr(var.getNum(), parser.getContext());
+    DimExpr dexpr{a};
+    expr = dexpr;
+  }
+
+  SparseTensorDimSliceAttr slice;
+  if (succeeded(parser.parseOptionalColon())) {
+    const auto loc = parser.getCurrentLocation();
+    Attribute attr;
+    FAILURE_IF_FAILED(parser.parseAttribute(attr))
+    slice = llvm::dyn_cast<SparseTensorDimSliceAttr>(attr);
+    ERROR_IF(!slice, "expected SparseTensorDimSliceAttr")
+  }
+
+  dimSpecs.emplace_back(var, expr, slice);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// `DimLvlMapParser` implementation for `LvlSpec`.
+//===----------------------------------------------------------------------===//
+
+/// Parses the parenthesized, comma-separated list of level-specifiers.
+ParseResult DimLvlMapParser::parseLvlSpecList() {
+  return parser.parseCommaSeparatedList(
+      Delimiter::Paren, [&]() -> ParseResult { return parseLvlSpec(); },
+      " in level-specifier list");
+}
+
+/// Parses one level-specifier of the form `[lvar =] expr : level-type` and
+/// appends it to `lvlSpecs`.
+ParseResult DimLvlMapParser::parseLvlSpec() {
+  // FIXME(wrengr): This implementation isn't actually going to work as-is,
+  // due to grammar ambiguity.  That is, assuming the current token is indeed
+  // a variable, we don't yet know whether that variable is supposed to
+  // be a binding vs being a usage that's part of the following AffineExpr.
+  // We can only disambiguate that by peeking at the next token to see whether
+  // it's the equals symbol or not.
+  //
+  // FIXME: proof of concept, assume it is new (viz. l0 = d0).
+  const auto res = parseVarBinding(VarKind::Level, /*isOptional=*/true);
+  FAILURE_IF_FAILED(res)
+  if (res->second) {
+    FAILURE_IF_FAILED(parser.parseEqual())
+  }
+  const LvlVar var = res->first.cast<LvlVar>();
+
+  // FIXME(wrengr): I don't think there's any way to implement this
+  // without replicating the bulk of `AffineParser::parseAffineExpr`
+  //
+  // TODO(wrengr): Also, need to make sure the parser uses
+  // `parseVarUsage(VarKind::Dimension)` so that every `AffineDimExpr`
+  // necessarily corresponds to a `DimVar` (never a `LvlVar`).
+  //
+  // FIXME: proof of concept, parse trivial dim vars (viz l0 = d0).
+  const auto loc = parser.getCurrentLocation();
+  auto use = parseVarUsage(VarKind::Dimension);
+  FAILURE_IF_FAILED(use)
+  // `parseVarUsage` uses `CreationPolicy::May`, so a name that was never
+  // declared in the dimension list comes back as a fresh *unbound* variable;
+  // `VarEnv::toVar` would then trip the `hasNum()` assertion in
+  // `VarInfo::getVar`.  Report that as a parse error instead.
+  ERROR_IF(!std::as_const(env).access(*use).hasNum(),
+           "use of unbound dimension variable")
+  AffineExpr a =
+      getAffineDimExpr(env.toVar(*use).getNum(), parser.getContext());
+  LvlExpr expr{a};
+
+  FAILURE_IF_FAILED(parser.parseColon())
+
+  const auto type = lvlTypeParser.parseLvlType(parser);
+  FAILURE_IF_FAILED(type)
+
+  lvlSpecs.emplace_back(var, expr, *type);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -0,0 +1,67 @@
+//===- LvlTypeParser.h - `DimLevelType` parser ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_LVLTYPEPARSER_H
+#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_LVLTYPEPARSER_H
+
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/StringMap.h"
+
+namespace mlir {
+namespace sparse_tensor {
+namespace ir_detail {
+
+//===----------------------------------------------------------------------===//
+// These macros are for generating a C++ expression of type
+// `std::initializer_list<std::pair<StringRef, DimLevelType>>` since there's
+// no way to construct an object of that type directly via C++ code.
+#define FOREVERY_LEVELTYPE(DO)                                                 \
+  DO(DimLevelType::Dense)                                                      \
+  DO(DimLevelType::Compressed)                                                 \
+  DO(DimLevelType::CompressedNu)                                               \
+  DO(DimLevelType::CompressedNo)                                               \
+  DO(DimLevelType::CompressedNuNo)                                             \
+  DO(DimLevelType::Singleton)                                                  \
+  DO(DimLevelType::SingletonNu)                                                \
+  DO(DimLevelType::SingletonNo)                                                \
+  DO(DimLevelType::SingletonNuNo)                                              \
+  DO(DimLevelType::CompressedWithHi)                                           \
+  DO(DimLevelType::CompressedWithHiNu)                                         \
+  DO(DimLevelType::CompressedWithHiNo)                                         \
+  DO(DimLevelType::CompressedWithHiNuNo)
+#define LEVELTYPE_INITLIST_ELEMENT(lvlType)                                    \
+  std::make_pair(StringRef(toMLIRString(lvlType)), lvlType),
+#define LEVELTYPE_INITLIST                                                     \
+  { FOREVERY_LEVELTYPE(LEVELTYPE_INITLIST_ELEMENT) }
+
+/// Parses level-type keywords into `DimLevelType` values.  The accepted
+/// keywords are exactly the `toMLIRString` spellings of the level-types
+/// enumerated by `FOREVERY_LEVELTYPE`.
+//
+// TODO(wrengr): Since this parser is non-trivial to construct, is there
+// any way to hook into the parsing process so that we construct it only once
+// at the beginning of parsing and then destroy it once parsing has finished?
+class LvlTypeParser {
+  /// Map from keyword spelling to the corresponding level-type.
+  const llvm::StringMap<DimLevelType> map;
+
+public:
+  explicit LvlTypeParser() : map(LEVELTYPE_INITLIST) {}
+#undef LEVELTYPE_INITLIST
+#undef LEVELTYPE_INITLIST_ELEMENT
+#undef FOREVERY_LEVELTYPE
+
+  /// Returns the level-type spelled by `str`, or nullopt if unknown.
+  std::optional<DimLevelType> lookup(StringRef str) const;
+  /// As above; also returns nullopt for a null attribute.
+  std::optional<DimLevelType> lookup(StringAttr str) const;
+  /// Parses a level-type keyword into `out`, emitting an error on failure.
+  ParseResult parseLvlType(AsmParser &parser, DimLevelType &out) const;
+  /// Parses a level-type keyword, emitting an error on failure.
+  FailureOr<DimLevelType> parseLvlType(AsmParser &parser) const;
+  // TODO(wrengr): `parseOptionalLvlType`?
+  // TODO(wrengr): `parseLvlTypeList`?
+};
+
+} // namespace ir_detail
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_LVLTYPEPARSER_H
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -0,0 +1,80 @@
+//===- LvlTypeParser.cpp - `DimLevelType` parser --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LvlTypeParser.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+using namespace mlir::sparse_tensor::ir_detail;
+
+//===----------------------------------------------------------------------===//
+// TODO(wrengr): rephrase these to do the trick for gobbling up any trailing
+// semicolon
+//
+// `FAILURE_IF_FAILED(STMT)` returns `failure()` from the enclosing function
+// whenever `failed(STMT)` holds; it emits no diagnostic of its own.  Because
+// it expands to a braced `if`, uses are written without a trailing `;`.
+//
+// NOTE: There's no way for `FAILURE_IF_FAILED` to simultaneously support
+// both `OptionalParseResult` and `InFlightDiagnostic` return types.
+// We can get the compiler to accept the code if we returned "`{}`",
+// however for `OptionalParseResult` that would become the nullopt result,
+// whereas for `InFlightDiagnostic` it would become a result that can
+// be implicitly converted to success.  By using "`failure()`" we ensure
+// that `OptionalParseResult` behaves as intended, however that means the
+// macro cannot be used for `InFlightDiagnostic` since there's no implicit
+// conversion.
+#define FAILURE_IF_FAILED(STMT)                                                \
+  if (failed(STMT)) {                                                          \
+    return failure();                                                          \
+  }
+
+// `ERROR_IF(COND, MSG)` returns `parser.emitError(loc, MSG)` from the
+// enclosing function whenever `COND` holds.
+//
+// Although `ERROR_IF` is phrased to return `InFlightDiagnostic`, that type
+// can be implicitly converted to all four of `LogicalResult`, `FailureOr`,
+// `ParseResult`, and `OptionalParseResult`.  (However, beware that the
+// conversion to `OptionalParseResult` doesn't properly delegate to
+// `InFlightDiagnostic::operator ParseResult`.)
+//
+// NOTE: this macro assumes `AsmParser parser` and `SMLoc loc` are in scope.
+#define ERROR_IF(COND, MSG)                                                    \
+  if (COND) {                                                                  \
+    return parser.emitError(loc, MSG);                                         \
+  }
+
+//===----------------------------------------------------------------------===//
+// `LvlTypeParser` implementation.
+//===----------------------------------------------------------------------===//
+
+std::optional<DimLevelType> LvlTypeParser::lookup(StringRef str) const {
+  // NOTE: `StringMap::lookup` will return a default-constructed value if
+  // the key isn't found; which for enums means zero, and therefore makes
+  // it impossible to distinguish between actual zero-DimLevelType vs
+  // not-found.  Whereas `StringMap::at` asserts that the key is found,
+  // which we don't want either.
+  const auto it = map.find(str);
+  return it == map.end() ? std::nullopt : std::make_optional(it->second);
+}
+
+std::optional<DimLevelType> LvlTypeParser::lookup(StringAttr str) const {
+  return str ? lookup(str.getValue()) : std::nullopt;
+}
+
+FailureOr<DimLevelType> LvlTypeParser::parseLvlType(AsmParser &parser) const {
+  DimLevelType out;
+  FAILURE_IF_FAILED(parseLvlType(parser, out))
+  return out;
+}
+
+ParseResult LvlTypeParser::parseLvlType(AsmParser &parser,
+                                        DimLevelType &out) const {
+  const auto loc = parser.getCurrentLocation();
+  StringRef strVal;
+  // `parseOptionalKeyword` reports nothing when the current token is not a
+  // keyword, so emit a diagnostic ourselves rather than failing silently.
+  ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
+           "expected level-type keyword")
+  const auto lvlType = lookup(strVal);
+  ERROR_IF(!lvlType, "unknown level-type '" + strVal + "'")
+  out = *lvlType;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/TemplateExtras.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/TemplateExtras.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/TemplateExtras.h
@@ -0,0 +1,87 @@
+//===- TemplateExtras.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_TEMPLATEEXTRAS_H
+#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_TEMPLATEEXTRAS_H
+
+#include <type_traits>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace sparse_tensor {
+namespace ir_detail {
+
+//===----------------------------------------------------------------------===//
+// These two templates are like `AsmPrinter::{,detect_}has_print_method`,
+// except they detect print methods taking `raw_ostream` (not `AsmPrinter`).
+// Detection uses a `const T &` receiver, matching the `T const &` that the
+// `operator<<` below actually calls `print` on.
+template <typename T>
+using has_print_method = decltype(std::declval<const T &>().print(
+    std::declval<llvm::raw_ostream &>()));
+template <typename T>
+using detect_has_print_method = llvm::is_detected<has_print_method, T>;
+template <typename T, typename R = void>
+using enable_if_has_print_method =
+    std::enable_if_t<detect_has_print_method<T>::value, R>;
+
+/// Generic template for defining `operator<<` overloads which delegate
+/// to `T::print(raw_ostream&) const`.  Note that there's already another
+/// generic template which defines `operator<<(AsmPrinterT&, T const&)`
+/// via delegating to `operator<<(raw_ostream&, T const&)`.
+template <typename T>
+inline enable_if_has_print_method<T, llvm::raw_ostream &>
+operator<<(llvm::raw_ostream &os, T const &t) {
+  t.print(os);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+/// Convert an enum to its underlying type.  This template is designed
+/// to avoid introducing implicit conversions to other integral types,
+/// and is a backport of C++23 `std::to_underlying`.
+template +constexpr std::underlying_type_t to_underlying(Enum e) noexcept { + return static_cast>(e); +} + +//===----------------------------------------------------------------------===// +template +static constexpr bool IsZeroCostAbstraction = + // These two predicates license the compiler to make several optimizations; + // some of which are explicitly documented by the C++ standard: + // + // + // However, some key optimizations aren't mentioned by the standard; e.g., + // that trivially-copyable enables passing-by-value, and the conjunction + // of trivially-copyable and trivially-destructible enables passing those + // values in registers rather than on the stack (cf., + // ). + std::is_trivially_copyable_v && std::is_trivially_destructible_v && + // This one helps ensure ABI compatibility (e.g., padding and alignment): + // + // + // In particular, the standard mentions that passing/returning a `struct` + // by value can sometimes introduce ABI overhead compared to using + // `enum class`; so this assertion is attempting to avoid that. + // + std::is_standard_layout_v && + // These two are what SmallVector uses to determine whether it can + // use memcpy. The commentary there mentions that it's intended to be + // an approximation of `is_trivially_copyable`, so this may be redundant + // with the above, but we include it just to make sure. 
+ llvm::is_trivially_copy_constructible::value && + llvm::is_trivially_move_constructible::value; + +//===----------------------------------------------------------------------===// + +} // namespace ir_detail +} // namespace sparse_tensor +} // namespace mlir + +#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_TEMPLATEEXTRAS_H diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h @@ -0,0 +1,460 @@ +//===- Var.h ----------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H +#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H + +#include "TemplateExtras.h" + +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/EnumeratedArray.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir { +namespace sparse_tensor { +namespace ir_detail { + +// Throughout this namespace we use the name `isWF` (is "well-formed") +// for predicates that detect intrinsic structural integrity criteria, +// and hence which should always be assertively true. Whereas we reserve +// the name `isValid` for predicates that detect extrinsic semantic +// integrity criteria, and hence which may legitimately return false even +// in well-formed programs. Moreover, "validity" is often a relational +// or contextual property, and therefore the same term may be considered +// valid in one context yet invalid in another. +// +// As an example of why we make this distinction, consider `Var`. 
+// A variable is well-formed if its kind and identifier are both well-formed;
+// this can be checked locally, and the resulting truth-value holds globally.
+// Whereas, a variable is valid with respect to a particular `Ranks` only if
+// it is within bounds; and a variable is valid with respect to a particular
+// `DimLvlMap` only if the variable is bound and all uses of the variable
+// are within the scope of that binding.
+
+// Throughout this namespace we use `enum class` types to form "newtypes".
+// The enum-based implementation of newtypes only serves to block implicit
+// conversions; it cannot enforce any wellformedness constraints, since
+// `enum class` permits using direct-list-initialization to construct
+// arbitrary values[1].  Consequently, we use the syntax "`E{u}`" whenever
+// we intend that ctor to be a noop (i.e., `std::is_same_v<decltype(u),
+// std::underlying_type_t<E>>`), since the compiler will ensure that that's
+// the case.  Whereas we only use the "`static_cast<E>(u)`" syntax when we
+// specifically intend to introduce conversions.
+//
+// [1]: https://en.cppreference.com/w/cpp/language/enum (the rules for
+//      list-initializing an `enum class` from its underlying type, C++17).
+
+//===----------------------------------------------------------------------===//
+/// The three kinds of variables that `Var` can be.
+///
+/// NOTE: The numerical values used to represent this enum should be
+/// treated as an implementation detail, not as part of the API.  In the
+/// API below we use the canonical ordering `{Symbol,Dimension,Level}` even
+/// though that does not agree with the numerical ordering of the numerical
+/// representation.
+enum class VarKind { Symbol = 1, Dimension = 0, Level = 2 };
+
+/// Whether `vk` holds one of the three enumerators declared above; i.e.,
+/// whether its underlying value lies in the closed range [0, 2].
+constexpr bool isWF(VarKind vk) {
+  const auto raw = to_underlying(vk);
+  return raw >= 0 && raw < 3;
+}
+
+/// Swaps `Dimension` and `Level`, but leaves `Symbol` the same.
+constexpr VarKind flipVarKind(VarKind vk) {
+  return VarKind{2 - to_underlying(vk)};
+}
+static_assert(flipVarKind(VarKind::Symbol) == VarKind::Symbol &&
+              flipVarKind(VarKind::Dimension) == VarKind::Level &&
+              flipVarKind(VarKind::Level) == VarKind::Dimension);
+
+/// Gets the ASCII character used as the prefix when printing `Var`.
+constexpr char toChar(VarKind vk) {
+  // If `isWF(vk)` then this computation's intermediate results are always
+  // in the range [-44..126] (where that lower bound is under worst-case
+  // rearranging of the expression); and `int_fast8_t` is the fastest type
+  // which can support that range without over-/underflow.
+  const auto vk_ = static_cast<int_fast8_t>(to_underlying(vk));
+  return static_cast<char>(100 + vk_ * (26 - vk_ * 11));
+}
+static_assert(toChar(VarKind::Symbol) == 's' &&
+              toChar(VarKind::Dimension) == 'd' &&
+              toChar(VarKind::Level) == 'l');
+
+//===----------------------------------------------------------------------===//
+/// The type of arrays indexed by `VarKind`.  `VarKind::Level` is the
+/// largest enumerator, so the array has three entries.
+template <typename T>
+using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;
+
+//===----------------------------------------------------------------------===//
+/// A concrete variable, to be used in our variant of `AffineExpr`.
+class Var {
+public:
+  /// Typedef to help disambiguate different uses of `unsigned`.
+  using Num = unsigned;
+
+private:
+  /// The underlying storage representation of `Var`.  Note that this type
+  /// should be kept distinct from `Num`.  Not only can they be different
+  /// C++ types (even though they currently happen to be the same), but
+  /// they also use different bitwise representations.
+  //
+  // FUTURE_CL(wrengr): Rather than rolling our own, we should
+  // consider using "llvm/ADT/Bitfields.h"; though that seems to only
+  // be used by LLVM for the sake of defining machine/assembly ops.
+  // Or we could consider abusing `PointerIntPair`...
+  using Impl = unsigned;
+  Impl impl;
+
+  /// The largest `Var::Num` supported by `Var::Impl`.  Two low-order
+  /// bits are reserved for storing the `VarKind`, and one high-order bit
+  /// is reserved for future use (e.g., to support `DenseMapInfo` while
+  /// maintaining the usual numeric values for "empty" and "tombstone").
+  static constexpr Num kMaxNum =
+      static_cast<Num>(std::numeric_limits<Impl>::max() >> 3);
+
+public:
+  // This must be public for `VarInfo` to use it (whereas we don't want
+  // to expose the `impl` field via friendship).
+  static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; }
+
+  constexpr Var(VarKind vk, Num n)
+      : impl((static_cast<Impl>(n) << 2) |
+             static_cast<Impl>(to_underlying(vk))) {
+    assert(isWF(vk) && "unknown VarKind");
+    assert(isWF_Num(n) && "Var::Num is too large");
+  }
+  Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {}
+  Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) {}
+
+  constexpr bool operator==(Var other) const { return impl == other.impl; }
+  constexpr bool operator!=(Var other) const { return !(*this == other); }
+
+  constexpr VarKind getKind() const { return static_cast<VarKind>(impl & 3); }
+  constexpr Num getNum() const { return static_cast<Num>(impl >> 2); }
+
+  template <typename U>
+  constexpr bool isa() const;
+  template <typename U>
+  constexpr U cast() const;
+  template <typename U>
+  constexpr U dyn_cast() const;
+
+  void print(llvm::raw_ostream &os) const;
+  void print(AsmPrinter &printer) const;
+  void dump() const;
+};
+static_assert(IsZeroCostAbstraction<Var>);
+
+class SymVar final : public Var {
+public:
+  static constexpr VarKind Kind = VarKind::Symbol;
+  static constexpr bool classof(Var const *var) {
+    return var->getKind() == Kind;
+  }
+  constexpr SymVar(Num sym) : Var(Kind, sym) {}
+  SymVar(AffineSymbolExpr symExpr) : Var(symExpr) {}
+};
+static_assert(IsZeroCostAbstraction<SymVar>);
+
+// TODO(wrengr): I'd like to give the ctors the types `DimVar(Dimension)`
+// and `LvlVar(Level)`, instead of their current types using `Num`;
+// however, that'd require importing "IR/SparseTensor.h" which nothing else
+// in this file requires.
+// Also beware the issues about implicit-conversion
+// from `uint64_t` to `Num`.
+class DimVar final : public Var {
+public:
+  static constexpr VarKind Kind = VarKind::Dimension;
+  static constexpr bool classof(Var const *var) {
+    return var->getKind() == Kind;
+  }
+  constexpr DimVar(Num dim) : Var(Kind, dim) {}
+  DimVar(AffineDimExpr dimExpr) : Var(Kind, dimExpr) {}
+};
+static_assert(IsZeroCostAbstraction<DimVar>);
+
+class LvlVar final : public Var {
+public:
+  static constexpr VarKind Kind = VarKind::Level;
+  static constexpr bool classof(Var const *var) {
+    return var->getKind() == Kind;
+  }
+  constexpr LvlVar(Num lvl) : Var(Kind, lvl) {}
+  LvlVar(AffineDimExpr lvlExpr) : Var(Kind, lvlExpr) {}
+};
+static_assert(IsZeroCostAbstraction<LvlVar>);
+
+// FIXME(wrengr): In order to get the `llvm::{isa,cast,dyn_cast}`
+// free-functions to work (instead of using our hand-rolled methods),
+// we'll need to define something like this:
+// ```
+// namespace llvm {
+// template <typename To>
+// struct CastInfo<To, Var> : OptionalValueCast<To, Var> {};
+// template <> struct ValueIsPresent<Var> {
+//   using UnwrappedType = Var;
+//   static inline bool isPresent(Var const&) { return true; }
+// };
+// } // namespace llvm
+// ```
+// The above will enable the type `llvm::dyn_cast<U>(Var) -> std::optional<U>`.
+//
+// FIXME(wrengr): The default `OptionalValueCast::doCast(Var const&)`
+// implementation uses the expression "`U(var)`", which means that all the
+// subclasses will need to define that upcasting-copy-ctor, and to ensure
+// safety/correctness will need to mark that ctor as private/protected,
+// which in turn means they'll need make the `CastInfo`/`OptionalValueCast`
+// classes friends.
+//
+// We run into similar issues with our hand-rolled methods, the only
+// difference is that the upcasting-copy-ctor would have type `U(Impl)`
+// instead of `U(Var)` and that we'd need to make the `Var` class a friend
+// rather than the `CastInfo`/`OptionalValueCast` classes.
+//
+template <typename U>
+constexpr bool Var::isa() const {
+  if constexpr (std::is_same_v<U, SymVar>)
+    return getKind() == VarKind::Symbol;
+  if constexpr (std::is_same_v<U, DimVar>)
+    return getKind() == VarKind::Dimension;
+  if constexpr (std::is_same_v<U, LvlVar>)
+    return getKind() == VarKind::Level;
+  // NOTE: The `AffineExpr::isa` implementation doesn't have a fallthrough
+  // case returning `false`; wrengr guesses that's so things will fail
+  // to compile whenever `!std::is_base_of<Var, U>`.  Though it's unclear
+  // why they implemented it that way rather than using SFINAE for that,
+  // especially since it would give better error messages.
+}
+
+/// Downcasts to the subclass `U`, asserting that the kind matches.
+/// The subclass ctors take a `Num`, so we rebuild from `getNum()` rather
+/// than re-deriving the bit layout of `impl` here.
+template <typename U>
+constexpr U Var::cast() const {
+  assert(isa<U>());
+  return U(getNum());
+}
+
+// NOTE(review): `U()` requires a default ctor, which none of `SymVar`,
+// `DimVar`, or `LvlVar` declares; so this cannot currently be instantiated.
+// See the FIXME above about returning `std::optional<U>` instead.
+template <typename U>
+constexpr U Var::dyn_cast() const {
+  return isa<U>() ? U(getNum()) : U();
+}
+
+//===----------------------------------------------------------------------===//
+// Forward-decl so that we can declare methods of `Ranks` and `VarSet`.
+class DimLvlExpr;
+
+//===----------------------------------------------------------------------===//
+class Ranks final {
+  // Not using `VarKindArray` since `EnumeratedArray` doesn't support constexpr.
+  // TODO(wrengr): to what extent do we actually care about constexpr here?
+  unsigned impl[3];
+
+  static constexpr unsigned to_index(VarKind vk) {
+    assert(isWF(vk) && "unknown VarKind");
+    return static_cast<unsigned>(to_underlying(vk));
+  }
+
+public:
+  // NOTE: Until C++20 a `constexpr` ctor must initialize every member before
+  // its body runs; the mem-initializer `impl()` value-initializes the whole
+  // array, after which the body may assign the individual elements.
+  constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
+      : impl() {
+    impl[to_index(VarKind::Symbol)] = symRank;
+    impl[to_index(VarKind::Dimension)] = dimRank;
+    impl[to_index(VarKind::Level)] = lvlRank;
+  }
+  Ranks(VarKindArray<unsigned> const &ranks)
+      : Ranks(ranks[VarKind::Symbol], ranks[VarKind::Dimension],
+              ranks[VarKind::Level]) {}
+
+  constexpr unsigned getRank(VarKind vk) const { return impl[to_index(vk)]; }
+  constexpr unsigned getSymRank() const { return getRank(VarKind::Symbol); }
+  constexpr unsigned getDimRank() const { return getRank(VarKind::Dimension); }
+  constexpr unsigned getLvlRank() const { return getRank(VarKind::Level); }
+
+  constexpr bool isValid(Var var) const {
+    return var.getNum() < getRank(var.getKind());
+  }
+  bool isValid(DimLvlExpr expr) const;
+};
+static_assert(IsZeroCostAbstraction<Ranks>);
+
+//===----------------------------------------------------------------------===//
+class VarSet final {
+  // If we're willing to give up the possibility of resizing the
+  // individual bitvectors, then we could flatten this into a single
+  // bitvector (akin to how `mlir::presburger::PresburgerSpace` does it);
+  // however, doing so would greatly complicate the implementation of the
+  // `occursIn(VarSet)` method.
+  VarKindArray<llvm::SmallBitVector> impl;
+
+public:
+  explicit VarSet(Ranks const &ranks);
+
+  // TODO(wrengr): can we come up with a single name that works for all three of
+  // these?
+  bool contains(Var var) const;
+  bool occursIn(VarSet const &vars) const;
+  bool occursIn(DimLvlExpr expr) const;
+
+  void add(Var var);
+  // TODO(wrengr): void add(VarSet const& vars);
+  void add(DimLvlExpr expr);
+};
+
+//===----------------------------------------------------------------------===//
+// TODO(wrengr): For good error messages we'll need to define something like:
+// ```class LocatedVar final { llvm::SMLoc loc; VarInfo::ID id; };```
+// to be the actual thing occurring in our variant of AffineExpr.
+// Though we may also want that struct to contain a pointer back to the
+// `VarEnv` which contains the `VarInfo` for that `VarInfo::ID`.
+//
+// To go along with this, the `VarInfo` record should drop its own `SMLoc`
+// field.
+
+//===----------------------------------------------------------------------===//
+/// A record of metadata for/about a variable, used by `VarEnv`.
+/// The principal goal of this record is to enable `VarEnv` to be used for
+/// incremental parsing; in particular, `VarInfo` allows the `Var::Num` to
+/// remain unknown, since each record is instead identified by `VarInfo::ID`.
+/// Therefore the `VarEnv` can freely allocate `VarInfo::ID` in whatever
+/// order it likes, irrespective of the binding order (`Var::Num`) of the
+/// associated variable.
+class VarInfo final {
+public:
+  /// Newtype for unique identifiers of `VarInfo` records, to ensure
+  /// they aren't confused with `Var::Num`.
+  enum class ID : unsigned {};
+
+private:
+  // FUTURE_CL(wrengr): We could use the high-bit of `Var::Impl` to
+  // store the `std::optional` bit, therefore allowing us to bitbash the
+  // `num` and `kind` fields together.
+  //
+  StringRef name;              // The bare-id used in the MLIR source.
+  llvm::SMLoc loc;             // The location of the first occurrence.
+  // TODO(wrengr): See the above `LocatedVar` note.
+  ID id;                       // The unique `VarInfo`-identifier.
+  std::optional<Var::Num> num; // The unique `Var`-identifier (if resolved).
+  VarKind kind;                // The kind of variable.
+
+public:
+  constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk,
+                    std::optional<Var::Num> n = {})
+      : name(name), loc(loc), id(id), num(n), kind(vk) {
+    assert(!name.empty() && "null StringRef");
+    assert(loc.isValid() && "null SMLoc");
+    assert(isWF(vk) && "unknown VarKind");
+    assert((!n || Var::isWF_Num(*n)) && "Var::Num is too large");
+  }
+
+  constexpr StringRef getName() const { return name; }
+  constexpr llvm::SMLoc getLoc() const { return loc; }
+  Location getLocation(AsmParser &parser) const {
+    return parser.getEncodedSourceLoc(loc);
+  }
+  constexpr ID getID() const { return id; }
+  constexpr VarKind getKind() const { return kind; }
+  constexpr std::optional<Var::Num> getNum() const { return num; }
+  constexpr bool hasNum() const { return num.has_value(); }
+  void setNum(Var::Num n);
+  constexpr Var getVar() const {
+    assert(hasNum());
+    return Var(kind, *num);
+  }
+  constexpr std::optional<Var> tryGetVar() const {
+    return num ? std::make_optional(Var(kind, *num)) : std::nullopt;
+  }
+};
+// We don't actually require this, since `VarInfo` is a proper struct
+// rather than a newtype.  But it passes, so for now we'll keep it around.
+static_assert(IsZeroCostAbstraction<VarInfo>);
+
+//===----------------------------------------------------------------------===//
+enum class CreationPolicy { MustNot, May, Must };
+
+class VarEnv final {
+  /// Map from `VarKind` to the next free `Var::Num`; used by `bindVar`.
+  VarKindArray<Var::Num> nextNum;
+  /// Map from `VarInfo::ID` to shared storage for the actual `VarInfo` objects.
+  SmallVector<VarInfo> vars;
+  /// Map from variable names to their `VarInfo::ID`.
+  llvm::StringMap<VarInfo::ID> ids;
+
+  VarInfo::ID nextID() const { return static_cast<VarInfo::ID>(vars.size()); }
+
+public:
+  VarEnv() : nextNum(0) {}
+
+  /// Gets the underlying storage for the `VarInfo` identified by
+  /// the `VarInfo::ID`.
+  ///
+  /// NOTE: The returned reference can become dangling if the `VarEnv`
+  /// object is mutated during the lifetime of the pointer.  Therefore,
+  /// client code should not store the reference nor otherwise allow it
+  /// to live too long.
+  //
+  // FUTURE_CL(wrengr): Consider trying to define/use a nested class
+  // `struct{VarEnv*; VarInfo::ID}` akin to `BitVector::reference`.
+  VarInfo const &access(VarInfo::ID id) const {
+    // `SmallVector::operator[]` already asserts the index is in-bounds.
+    return vars[to_underlying(id)];
+  }
+  VarInfo const *access(std::optional<VarInfo::ID> oid) const {
+    return oid ? &access(*oid) : nullptr;
+  }
+
+  /// Gets the `Var` for the given identifier; asserts that it is bound.
+  Var toVar(VarInfo::ID id) const { return vars[to_underlying(id)].getVar(); }
+
+private:
+  VarInfo &access(VarInfo::ID id) {
+    return const_cast<VarInfo &>(std::as_const(*this).access(id));
+  }
+  VarInfo *access(std::optional<VarInfo::ID> oid) {
+    return const_cast<VarInfo *>(std::as_const(*this).access(oid));
+  }
+
+public:
+  /// Attempts to look up the variable with the given name.
+  std::optional<VarInfo::ID> lookup(StringRef name) const;
+
+  /// Attempts to create a new currently-unbound variable.  When a variable
+  /// of that name already exists: if `verifyUsage` is true, then will assert
+  /// that the variable has the same kind and a consistent location; otherwise,
+  /// when `verifyUsage` is false, this is a noop.  Returns the identifier
+  /// for the variable with the given name (i.e., either the newly created
+  /// variable, or the pre-existing variable), and a bool indicating whether
+  /// a new variable was created.
+  std::pair<VarInfo::ID, bool> create(StringRef name, llvm::SMLoc loc,
+                                      VarKind vk, bool verifyUsage = false);
+
+  /// Attempts to lookup or create a variable according to the given
+  /// `CreationPolicy`.  Returns nullopt in one of two circumstances:
+  /// (1) the policy says we `Must` create, yet the variable already exists;
+  /// (2) the policy says we `MustNot` create, yet no such variable exists.
+  /// Otherwise, if the variable already exists then it is validated against
+  /// the given kind and location to ensure consistency.
+  //
+  // TODO(wrengr): Define an enum of error codes, to avoid `nullopt`-blindness
+  // TODO(wrengr): Prolly want to rename this to `create` and move the
+  // current method of that name to being a private `createImpl`.
+  std::optional<std::pair<VarInfo::ID, bool>>
+  lookupOrCreate(CreationPolicy policy, StringRef name, llvm::SMLoc loc,
+                 VarKind vk);
+
+  /// Binds the given variable to the next free `Var::Num` for its `VarKind`.
+  Var bindVar(VarInfo::ID id);
+
+  /// Creates a new variable of the given kind and immediately binds it.
+  /// This should only be used whenever the variable is known to be unused
+  /// and therefore does not have a name.
+  Var bindUnusedVar(VarKind vk);
+
+  InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const;
+
+  Ranks getRanks() const { return Ranks(nextNum); }
+};
+
+//===----------------------------------------------------------------------===//
+
+} // namespace ir_detail
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -0,0 +1,292 @@
+//===- Var.cpp ------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Var.h"
+#include "DimLvlMap.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+using namespace mlir::sparse_tensor::ir_detail;
+
+//===----------------------------------------------------------------------===//
+// `Var` implementation.
+//===----------------------------------------------------------------------===//
+
+void Var::print(AsmPrinter &printer) const { print(printer.getStream()); }
+
+// Prints the `VarKind`'s character (via `toChar`) followed by the `Var::Num`.
+void Var::print(llvm::raw_ostream &os) const {
+  os << toChar(getKind()) << getNum();
+}
+
+void Var::dump() const {
+  print(llvm::errs());
+  llvm::errs() << "\n";
+}
+
+//===----------------------------------------------------------------------===//
+// `Ranks` implementation.
+//===----------------------------------------------------------------------===//
+
+/// Checks that every variable occurring in `expr` is within these ranks:
+/// symbols against the symbol-rank, and the one non-symbol `VarKind` that
+/// `expr` allows against that kind's rank.
+bool Ranks::isValid(DimLvlExpr expr) const {
+  // FIXME(wrengr): we have cases without affine expr at an early point
+  const AffineExpr affine = expr.getAffineExpr();
+  if (!affine)
+    return true;
+  // Each `DimLvlExpr` only allows one kind of non-symbol variable.
+  int64_t maxSym = -1, maxVar = -1;
+  // Materialize the singleton list-of-lists in named locals (rather than
+  // via the "`{{...}}`" braced-init syntax), so that the storage both
+  // `ArrayRef`s point into is obviously alive for the whole call.
+  const ArrayRef<AffineExpr> exprs(affine);
+  const ArrayRef<ArrayRef<AffineExpr>> exprsList(exprs);
+  mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>(exprsList, maxVar, maxSym);
+  // TODO(wrengr): We may want to add a call to `LLVM_DEBUG` like
+  // `willBeValidAffineMap` does.  And/or should return `InFlightDiagnostic`
+  // instead of bool.
+  //
+  // Compare as signed `int64_t`: both maxima stay `-1` when `expr` has no
+  // variables of that kind, which must compare less than any rank.
+  return maxSym < static_cast<int64_t>(getSymRank()) &&
+         maxVar < static_cast<int64_t>(getRank(expr.getAllowedVarKind()));
+}
+
+//===----------------------------------------------------------------------===//
+// `VarSet` implementation.
+//===----------------------------------------------------------------------===//
+
+static constexpr const VarKind everyVarKind[] = {
+    VarKind::Dimension, VarKind::Symbol, VarKind::Level};
+
+VarSet::VarSet(Ranks const &ranks) {
+  // Each bit-vector must be *sized* (not merely reserved) to the rank of
+  // its `VarKind`: `reserve` leaves `size() == 0`, which would make every
+  // subsequent `add(Var)` index out of bounds.
+  for (const auto vk : everyVarKind)
+    impl[vk].resize(ranks.getRank(vk));
+}
+
+bool VarSet::contains(Var var) const {
+  // A variable beyond the rank this set was constructed with cannot be a
+  // member, so report it as absent rather than asserting.  This matches
+  // the `anyCommon`-based `occursIn(VarSet)`, which never reports such
+  // out-of-range variables as common.
+  const llvm::SmallBitVector &bits = impl[var.getKind()];
+  if (var.getNum() >= bits.size())
+    return false;
+  return bits[var.getNum()];
+}
+
+/// Returns true iff some variable belongs to both this set and `other`.
+bool VarSet::occursIn(VarSet const &other) const {
+  for (const auto vk : everyVarKind)
+    if (impl[vk].anyCommon(other.impl[vk]))
+      return true;
+  return false;
+}
+
+/// Returns true iff some variable occurring in `expr` belongs to this set.
+bool VarSet::occursIn(DimLvlExpr expr) const {
+  if (!expr)
+    return false;
+  switch (expr.getAffineKind()) {
+  case AffineExprKind::Constant:
+    return false;
+  case AffineExprKind::SymbolId:
+    return contains(expr.castSymVar());
+  case AffineExprKind::DimId:
+    return contains(expr.castDimLvlVar());
+  case AffineExprKind::Add:
+  case AffineExprKind::Mul:
+  case AffineExprKind::Mod:
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv: {
+    const auto [lhs, op, rhs] = expr.unpackBinop();
+    (void)op;
+    return occursIn(lhs) || occursIn(rhs);
+  }
+  }
+  llvm_unreachable("unknown AffineExprKind");
+}
+
+void VarSet::add(Var var) {
+  // Unlike `contains`, this does not tolerate out-of-range variables:
+  // the set cannot grow beyond the `Ranks` it was constructed with.
+  // FIXME(wrengr): perhaps we'd rather have this be a noop on OOB? or to
+  // grow the underlying bitvectors on OOB?
+  llvm::SmallBitVector &bits = impl[var.getKind()];
+  assert(var.getNum() < bits.size() && "Var::Num is out of bounds for VarSet");
+  bits[var.getNum()] = true;
+}
+
+// TODO(wrengr): void VarSet::add(VarSet const& other);
+
+/// Adds every variable occurring in `expr` to this set.
+void VarSet::add(DimLvlExpr expr) {
+  if (!expr)
+    return;
+  switch (expr.getAffineKind()) {
+  case AffineExprKind::Constant:
+    return;
+  case AffineExprKind::SymbolId:
+    add(expr.castSymVar());
+    return;
+  case AffineExprKind::DimId:
+    add(expr.castDimLvlVar());
+    return;
+  case AffineExprKind::Add:
+  case AffineExprKind::Mul:
+  case AffineExprKind::Mod:
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv: {
+    const auto [lhs, op, rhs] = expr.unpackBinop();
+    (void)op;
+    add(lhs);
+    add(rhs);
+    return;
+  }
+  }
+  llvm_unreachable("unknown AffineExprKind");
+}
+
+//===----------------------------------------------------------------------===//
+// `VarInfo` implementation.
+//===----------------------------------------------------------------------===//
+
+void VarInfo::setNum(Var::Num n) {
+  assert(!hasNum() && "Var::Num is already set");
+  assert(Var::isWF_Num(n) && "Var::Num is too large");
+  num = n;
+}
+
+//===----------------------------------------------------------------------===//
+// `VarEnv` implementation.
+//===----------------------------------------------------------------------===//
+
+/// Helper function for `assertUsageConsistency` to better handle SMLoc
+/// mismatches.
+// TODO(wrengr): If we switch to the `LocatedVar` design, then there's
+// no need for anything like `minSMLoc` since `assertUsageConsistency`
+// won't need to do anything about locations.
+LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
+minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
+  const auto loc1 = parser.getEncodedSourceLoc(sm1).dyn_cast<FileLineColLoc>();
+  assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
+  const auto loc2 = parser.getEncodedSourceLoc(sm2).dyn_cast<FileLineColLoc>();
+  assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`");
+  // The asserts vanish in release builds, so never dereference a null
+  // location there: treat it like any other incomparable pair.
+  if (!loc1 || !loc2 || loc1.getFilename() != loc2.getFilename())
+    return SMLoc();
+  const auto pair1 = std::make_pair(loc1.getLine(), loc1.getColumn());
+  const auto pair2 = std::make_pair(loc2.getLine(), loc2.getColumn());
+  return pair1 <= pair2 ? sm1 : sm2;
+}
+
+/// Debug-only check that the `VarInfo` stored under `id` records exactly
+/// the given `name` and `id`.
+LLVM_ATTRIBUTE_UNUSED static void
+assertInternalConsistency(VarEnv const &env, VarInfo::ID id, StringRef name) {
+#ifndef NDEBUG
+  const auto &var = env.access(id);
+  assert(var.getName() == name && "found inconsistent name");
+  assert(var.getID() == id && "found inconsistent VarInfo::ID");
+#endif // NDEBUG
+}
+
+/// Debug-only check that a re-occurrence of variable `id` agrees with its
+/// existing `VarInfo`.  Currently only the `VarKind` is checked; `loc` is
+/// not yet used (see the TODO below).
+// NOTE(wrengr): if we can actually obtain an `AsmParser` for `minSMLoc`
+// (or find some other way to convert SMLoc to FileLineColLoc), then this
+// would no longer be `const VarEnv` (and couldn't be a free-function either).
+LLVM_ATTRIBUTE_UNUSED static void assertUsageConsistency(VarEnv const &env,
+                                                         VarInfo::ID id,
+                                                         llvm::SMLoc loc,
+                                                         VarKind vk) {
+#ifndef NDEBUG
+  const auto &var = env.access(id);
+  assert(var.getKind() == vk &&
+         "a variable of that name already exists with a different VarKind");
+  // Since the same variable can occur at several locations,
+  // it would not be appropriate to do `assert(var.getLoc() == loc)`.
+  /* TODO(wrengr):
+  const auto minLoc = minSMLoc(_, var.getLoc(), loc);
+  assert(minLoc && "Location mismatch/incompatibility");
+  var.loc = minLoc;
+  // */
+#endif // NDEBUG
+}
+
+std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
+  // NOTE: `StringMap::lookup` will return a default-constructed value if
+  // the key isn't found; which for enums means zero, and therefore makes
+  // it impossible to distinguish between actual zero-VarInfo::ID vs not-found.
+  // Whereas `StringMap::at` asserts that the key is found, which we don't
+  // want either.
+  const auto iter = ids.find(name);
+  if (iter == ids.end())
+    return std::nullopt;
+  const auto id = iter->second;
+#ifndef NDEBUG
+  assertInternalConsistency(*this, id, name);
+#endif // NDEBUG
+  return id;
+}
+
+std::pair<VarInfo::ID, bool> VarEnv::create(StringRef name, llvm::SMLoc loc,
+                                            VarKind vk, bool verifyUsage) {
+  // `try_emplace` only inserts when `name` is new, so `nextID()` is the ID
+  // of the `VarInfo` appended below.
+  const auto &[iter, didInsert] = ids.try_emplace(name, nextID());
+  const auto id = iter->second;
+  if (didInsert) {
+    vars.emplace_back(id, name, loc, vk);
+  } else {
+#ifndef NDEBUG
+    assertInternalConsistency(*this, id, name);
+    if (verifyUsage)
+      assertUsageConsistency(*this, id, loc, vk);
+#endif // NDEBUG
+  }
+  return std::make_pair(id, didInsert);
+}
+
+std::optional<std::pair<VarInfo::ID, bool>>
+VarEnv::lookupOrCreate(CreationPolicy policy, StringRef name, llvm::SMLoc loc,
+                       VarKind vk) {
+  switch (policy) {
+  case CreationPolicy::MustNot: {
+    const auto oid = lookup(name);
+    if (!oid)
+      return std::nullopt; // Doesn't exist, but must not create.
+#ifndef NDEBUG
+    assertUsageConsistency(*this, *oid, loc, vk);
+#endif // NDEBUG
+    return std::make_pair(*oid, false);
+  }
+  case CreationPolicy::May:
+    return create(name, loc, vk, /*verifyUsage=*/true);
+  case CreationPolicy::Must: {
+    const auto res = create(name, loc, vk, /*verifyUsage=*/false);
+    const auto didCreate = res.second;
+    if (!didCreate)
+      return std::nullopt; // Already exists, but must create.
+    return res;
+  }
+  }
+  llvm_unreachable("unknown CreationPolicy");
+}
+
+Var VarEnv::bindUnusedVar(VarKind vk) { return Var(vk, nextNum[vk]++); }
+Var VarEnv::bindVar(VarInfo::ID id) {
+  auto &info = access(id);
+  const auto var = bindUnusedVar(info.getKind());
+  // NOTE: `setNum` already checks wellformedness of the `Var::Num`.
+  info.setNum(var.getNum());
+  return var;
+}
+
+// TODO(wrengr): Alternatively there's `mlir::emitError(Location, Twine const&)`
+// which is what `Operation::emitError` uses; though I'm not sure if
+// that's appropriate to use here...  But if it is, then that means
+// we can have `VarInfo` store `Location` rather than `SMLoc`, which
+// means we can use `FusedLoc` to handle the combination issue in
+// `VarEnv::lookupOrCreate`.
+//
+// TODO(wrengr): is there any way to combine multiple IFDs, so that
+// we can report all unbound variables instead of just the first one
+// encountered?
+//
+InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const {
+  for (const auto &var : vars)
+    if (!var.hasNum())
+      return parser.emitError(var.getLoc(),
+                              "Unbound variable: " + var.getName());
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -8,6 +8,8 @@
 
 #include <utility>
 
+#include "Detail/DimLvlMapParser.h"
+
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
@@ -449,8 +451,8 @@
 
   StringRef attrName;
   // Exactly 6 keys.
-  SmallVector keys = {"lvlTypes", "dimToLvl", "posWidth",
-                                    "crdWidth", "dimSlices"};
+  SmallVector keys = {"lvlTypes", "dimToLvl", "posWidth",
+                                    "crdWidth", "dimSlices", "NEW_SYNTAX"};
   while (succeeded(parser.parseOptionalKeyword(&attrName))) {
     if (!llvm::is_contained(keys, attrName)) {
       parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
@@ -514,6 +516,20 @@
       if (!finished)
         return {};
       RETURN_ON_FAIL(parser.parseRSquare())
+    } else if (attrName == "NEW_SYNTAX") {
+      // Note that we are in the process of migrating to a new STEA surface
+      // syntax. While this is ongoing we use the temporary "NEW_SYNTAX = ...."
+      // to switch to the new parser. This allows us to gradually migrate
+      // examples over to the new surface syntax before making the complete
+      // switch once work is completed.
+      // TODO: replace everything here with new STEA surface syntax parser
+      ir_detail::DimLvlMapParser cParser(parser);
+      auto res = cParser.parseDimLvlMap();
+      RETURN_ON_FAIL(res);
+      // Proof of concept: only lvlTypes are kept (dimToLvl is dropped).
+      // TODO: use DimLvlMap directly as storage representation
+      for (unsigned i = 0, e = res->getLvlRank(); i < e; i++)
+        lvlTypes.push_back(res->getDimLevelType(i));
     }
 
     // Only the last item can omit the comma
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -128,3 +128,20 @@
 // CHECK-LABEL: func private @sparse_slice(
 // CHECK-SAME: tensor>
 func.func private @sparse_slice(tensor)
+
+// -----
+
+// Migration plan for the new STEA surface syntax: use NEW_SYNTAX on
+// selected examples first, and then (TODO) remove it once all examples
+// have been migrated.
+
+#NewSurfaceSyntax = #sparse_tensor.encoding<{
+  NEW_SYNTAX =
+  (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+}>
+
+// CHECK-LABEL: func private @foo(
+// CHECK-SAME: tensor>
+func.func private @foo(%arg0: tensor) {
+  return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2557,7 +2557,18 @@
 
 cc_library(
     name = "SparseTensorDialect",
-    srcs = ["lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp"],
+    srcs = [
+        "lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp",
+        "lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h",
+        "lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp",
+        "lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h",
+        "lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp",
+        "lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h",
+        "lib/Dialect/SparseTensor/IR/Detail/TemplateExtras.h",
+        "lib/Dialect/SparseTensor/IR/Detail/Var.cpp",
+        "lib/Dialect/SparseTensor/IR/Detail/Var.h",
+        "lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp",
+    ],
     hdrs = [
         "include/mlir/Dialect/SparseTensor/IR/SparseTensor.h",
"include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h",