diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -356,16 +356,10 @@ PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value srcTensor = op.getSrc(); - auto srcTp = getRankedTensorType(srcTensor); - auto dstTp = getRankedTensorType(op.getResult()); - - SparseTensorType srcStt(srcTp); - SparseTensorType dstStt(dstTp); - - const auto encSrc = srcStt.getEncoding(); - if (!srcStt.hasEncoding() || !dstStt.hasEncoding()) { + const auto srcTp = getSparseTensorType(srcTensor); + const auto dstTp = getSparseTensorType(op.getResult()); + if (!srcTp.hasEncoding() || !dstTp.hasEncoding()) return failure(); - } // Generate code to represent the static dimension constants or compute // the dynamic dimension values. @@ -373,11 +367,11 @@ sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor); SmallVector dstSizes; SmallVector dstDynSizes; - if (dstTp.hasStaticShape()) { - for (auto d : dstTp.getShape()) + if (dstTp.hasStaticDimShape()) { + for (Dimension d : dstTp.getDimShape()) dstSizes.push_back(constantIndex(rewriter, loc, d)); } else { - ArrayRef dstShape = dstTp.getShape(); + ArrayRef dstShape = dstTp.getDimShape(); genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape, op.getReassociationIndices()); for (auto [idx, shape] : llvm::enumerate(dstShape)) { @@ -389,8 +383,8 @@ // Only need a unordered COO buffer if input and output are not sorted // in the same way. Type bufferTp = - srcStt.isAllOrdered() && srcStt.isIdentity() && dstStt.isIdentity() - ? dstTp + srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity() + ? dstTp.getRankedTensorType() : getUnorderedCOOFromType(dstTp); Value buffer = @@ -406,11 +400,12 @@ // followed by an optional // %t = sparse_tensor.cast %tmp // depending on whether the input/output are sorted in the same way. + const auto encSrc = srcTp.getEncoding(); ForeachOp foreachOp = rewriter.create( loc, srcTensor, buffer, [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v, ValueRange reduc) { - const Dimension dimRank = srcTp.getRank(); + const Dimension dimRank = srcTp.getDimRank(); SmallVector srcDcvs; srcDcvs.reserve(dimRank); for (Dimension d = 0; d < dimRank; d++) { @@ -427,7 +422,8 @@ Value t = rewriter.create(loc, foreachOp.getResult(0), true); if (bufferTp != dstTp) { - Value converted = rewriter.create(loc, dstTp, t).getResult(); + auto dstRTT = dstTp.getRankedTensorType(); + Value converted = rewriter.create(loc, dstRTT, t).getResult(); rewriter.create(loc, t); t = converted; }