diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h @@ -294,6 +294,7 @@ unsigned symRank; SmallVector dimSpecs; SmallVector lvlSpecs; + bool mustPrintLvlVars; // Checks for integrity of variable-binding structure. // This is already called by the ctor. diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp @@ -249,7 +249,8 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef dimSpecs, ArrayRef lvlSpecs) - : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs) { + : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs), + mustPrintLvlVars(false) { // First, check integrity of the variable-binding structure. // NOTE: This establishes the invariant that calls to `VarSet::add` // below cannot cause OOB errors. @@ -274,8 +275,14 @@ for (const auto &dimSpec : dimSpecs) if (!dimSpec.canElideExpr()) usedVars.add(dimSpec.getExpr()); - for (auto &lvlSpec : this->lvlSpecs) - lvlSpec.setElideVar(!usedVars.contains(lvlSpec.getBoundVar())); + for (auto &lvlSpec : this->lvlSpecs) { + // Is this LvlVar used in any overt expression? + const bool isUsed = usedVars.contains(lvlSpec.getBoundVar()); + // This LvlVar can be elided iff it isn't overtly used. + lvlSpec.setElideVar(!isUsed); + // If any LvlVar cannot be elided, then must forward-declare all LvlVars. + mustPrintLvlVars = mustPrintLvlVars || isUsed; + } } bool DimLvlMap::isWF() const { @@ -314,12 +321,21 @@ os << ']'; } + // LvlVar forward-declarations. + if (mustPrintLvlVars) { + os << '{'; + llvm::interleaveComma( + lvlSpecs, os, [&](LvlSpec const &spec) { os << spec.getBoundVar(); }); + os << '}'; + } + // Dimension specifiers. os << '('; llvm::interleaveComma( dimSpecs, os, [&](DimSpec const &spec) { spec.print(os, wantElision); }); os << ") -> ("; // Level specifiers. + wantElision = wantElision && !mustPrintLvlVars; llvm::interleaveComma( lvlSpecs, os, [&](LvlSpec const &spec) { spec.print(os, wantElision); }); os << ')';