diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1235,29 +1235,28 @@ matchAndRewrite(PackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - const auto rtp = getRankedTensorType(op.getResult()); - assert(isUniqueCOOType(rtp)); + const auto stt = getSparseTensorType(op.getResult()); + assert(isUniqueCOOType(stt)); SmallVector fields; Location loc = op.getLoc(); foreachFieldAndTypeInSparseTensor( - rtp, - [&rewriter, &fields, &op, rtp, + stt, + [&rewriter, &fields, &op, stt, loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind, Level /*lvl*/, DimLevelType /*dlt*/) -> bool { assert(fields.size() == fIdx); - auto enc = getSparseTensorEncoding(rtp); Value field; switch (fKind) { case SparseTensorFieldKind::StorageSpec: - field = SparseTensorSpecifier::getInitValue(rewriter, loc, rtp); + field = SparseTensorSpecifier::getInitValue(rewriter, loc, stt); break; case SparseTensorFieldKind::PosMemRef: { // TACO-style COO starts with a PosBuffer // By creating a constant value for it, we avoid the complexity of // memory management. - const auto posTp = enc.getPosType(); + const auto posTp = stt.getPosType(); auto tensorType = RankedTensorType::get({2}, posTp); auto memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); @@ -1306,13 +1305,11 @@ return true; }); - MutSparseTensorDescriptor desc(rtp, fields); + MutSparseTensorDescriptor desc(stt, fields); auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getValues(), 0); - // FIXME: should use `SparseTensorType::getLvlRank` in lieu of - // `RankedTensorType::getRank`, because the latter introduces dim/lvl - // ambiguity. - for (Level lvl = 0, lvlRank = rtp.getRank(); lvl < lvlRank; lvl++) { - const auto sh = rtp.getShape()[lvl]; + for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) { + // FIXME: dim/lvl confusion! + const auto sh = stt.getDimShape()[lvl]; assert(!ShapedType::isDynamic(sh)); desc.setLvlSize(rewriter, loc, lvl, constantIndex(rewriter, loc, sh)); if (lvl == 0)