diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -169,6 +169,10 @@
 
     /// Returns the type for index storage based on indexBitWidth
     Type getIndexType() const;
+
+    /// Constructs a new encoding with the dimOrdering and higherOrdering
+    /// reset to the default/identity.
+    SparseTensorEncodingAttr withoutOrdering() const;
   }];
 
   let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -57,6 +57,12 @@
   return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType;
 }
 
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
+  return SparseTensorEncodingAttr::get(
+      getContext(), getDimLevelType(), AffineMap(), AffineMap(),
+      getPointerBitWidth(), getIndexBitWidth());
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   if (failed(parser.parseLess()))
     return {};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -529,9 +529,7 @@
   assert(elemTp == dstTp.getElementType() &&
          "reshape should not change element type");
   // Start an iterator over the source tensor (in original index order).
-  auto noPerm = SparseTensorEncodingAttr::get(
-      op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
-      encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
+  const auto noPerm = encSrc.withoutOrdering();
   SmallVector<Value> srcDimSizes =
       getDimSizes(rewriter, loc, encSrc, srcTp, adaptor.getSrc());
   NewCallParams params(rewriter, loc);
@@ -596,9 +594,7 @@
   Type elemTp = tensorTp.getElementType();
 
   // Start an iterator over the tensor (in original index order).
-  auto noPerm = SparseTensorEncodingAttr::get(
-      rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
-      enc.getPointerBitWidth(), enc.getIndexBitWidth());
+  const auto noPerm = enc.withoutOrdering();
   SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, noPerm, tensorTp, t);
   Value iter = NewCallParams(rewriter, loc)
                    .genBuffers(noPerm, dimSizes, tensorTp)
@@ -1485,9 +1481,7 @@
     auto encSrc = getSparseTensorEncoding(srcType);
     SmallVector<Value> dimSizes =
         getDimSizes(rewriter, loc, encSrc, srcType, src);
-    auto enc = SparseTensorEncodingAttr::get(
-        op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
-        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
+    const auto enc = encSrc.withoutOrdering();
     Value coo = NewCallParams(rewriter, loc)
                     .genBuffers(enc, dimSizes, srcType)
                     .genNewCall(Action::kToCOO, src);