diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -290,16 +290,6 @@
 
 //===----------------------------------------------------------------------===//
 class DimLvlMap final {
-  // TODO(wrengr): Need to define getters
-  unsigned symRank;
-  SmallVector<DimSpec> dimSpecs;
-  SmallVector<LvlSpec> lvlSpecs;
-  bool mustPrintLvlVars;
-
-  // Checks for integrity of variable-binding structure.
-  // This is already called by the ctor.
-  [[nodiscard]] bool isWF() const;
-
 public:
   DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
             ArrayRef<LvlSpec> lvlSpecs);
@@ -310,11 +300,41 @@
   unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
   Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }
 
-  DimLevelType getDimLevelType(unsigned i) { return lvlSpecs[i].getType(); }
+  ArrayRef<DimSpec> getDims() const { return dimSpecs; }
+  const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; }
+  SparseTensorDimSliceAttr getDimSlice(Dimension dim) const {
+    return getDim(dim).getSlice();
+  }
+
+  ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; }
+  const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; }
+  DimLevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }
+
+  AffineMap getDimToLvlMap(MLIRContext *context) const;
+  AffineMap getLvlToDimMap(MLIRContext *context) const;
 
   void print(llvm::raw_ostream &os, bool wantElision = true) const;
   void print(AsmPrinter &printer, bool wantElision = true) const;
   void dump() const;
+
+private:
+  /// Checks for integrity of variable-binding structure.
+  /// This is already called by the ctor.
+  [[nodiscard]] bool isWF() const;
+
+  /// Helper function to call `DimSpec::setExpr` while asserting that
+  /// the invariant established by `DimLvlMap::isWF` is maintained.
+  /// This is used by the ctor.
+  void setDimExpr(Dimension dim, DimExpr expr) {
+    assert(expr && getRanks().isValid(expr));
+    dimSpecs[dim].setExpr(expr);
+  }
+
+  // All these fields are const-after-ctor.
+  unsigned symRank;
+  SmallVector<DimSpec> dimSpecs;
+  SmallVector<LvlSpec> lvlSpecs;
+  bool mustPrintLvlVars;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -262,10 +262,8 @@
   // needs to happen before the code for setting every `LvlSpec::elideVar`,
   // since if the LvlVar is only used in elided DimExpr, then the
   // LvlVar should also be elided.
-  // NOTE: Whenever we set a new DimExpr, we must make sure to validate it
-  // against our ranks, to restore the invariant established by `isWF` above.
-  // TODO(wrengr): We might should adjust the `DimLvlExpr` ctor to take a
-  // `Ranks` argument and perform the validation then.
+  // NOTE: Be sure to use `DimLvlMap::setDimExpr` for setting the new exprs,
+  // to ensure that we maintain the invariant established by `isWF` above.
 
   // Third, we set every `LvlSpec::elideVar` according to whether that
   // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
@@ -300,6 +298,22 @@
   return true;
 }
 
+AffineMap DimLvlMap::getDimToLvlMap(MLIRContext *context) const {
+  SmallVector<AffineExpr> lvlAffines;
+  lvlAffines.reserve(getLvlRank());
+  for (const auto &lvlSpec : lvlSpecs)
+    lvlAffines.push_back(lvlSpec.getExpr().getAffineExpr());
+  return AffineMap::get(getDimRank(), getSymRank(), lvlAffines, context);
+}
+
+AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
+  SmallVector<AffineExpr> dimAffines;
+  dimAffines.reserve(getDimRank());
+  for (const auto &dimSpec : dimSpecs)
+    dimAffines.push_back(dimSpec.getExpr().getAffineExpr());
+  return AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context);
+}
+
 void DimLvlMap::dump() const {
   print(llvm::errs(), /*wantElision=*/false);
   llvm::errs() << "\n";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -115,7 +115,7 @@
 }
 
 void VarSet::add(Var var) {
-  // NOTE: `SmallBitVactor::operator[]` will raise assertion errors for OOB.
+  // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
   impl[var.getKind()][var.getNum()] = true;
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -530,8 +530,8 @@
       RETURN_ON_FAIL(res);
       // Proof of concept result.
       // TODO: use DimLvlMap directly as storage representation
-      for (unsigned i = 0, e = res->getLvlRank(); i < e; i++)
-        lvlTypes.push_back(res->getDimLevelType(i));
+      for (Level lvl = 0, lvlRank = res->getLvlRank(); lvl < lvlRank; lvl++)
+        lvlTypes.push_back(res->getLvlType(lvl));
     }
 
     // Only the last item can omit the comma