diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -85,6 +85,22 @@
   return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
 }
 
+/// Generates a constant of the internal type encoding for pointer
+/// overhead storage, as given by the pointer bit width of `enc`.
+static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter,
+                                         Location loc,
+                                         SparseTensorEncodingAttr enc) {
+  return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth());
+}
+
+/// Generates a constant of the internal type encoding for index overhead
+/// storage, as given by the index bit width of `enc`.
+static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter,
+                                       Location loc,
+                                       SparseTensorEncodingAttr enc) {
+  return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth());
+}
+
 /// Generates a constant of the internal type encoding for primary storage.
 static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
                                          Location loc, Type tp) {
@@ -277,10 +293,8 @@
   params.push_back(genBuffer(rewriter, loc, rev));
   // Secondary and primary types encoding.
   ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
-  params.push_back(
-      constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()));
-  params.push_back(
-      constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()));
+  params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
+  params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
   params.push_back(
       constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
   // User action and pointer.
@@ -598,10 +612,8 @@
           encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
       newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
       Value coo = genNewCall(rewriter, op, params);
-      params[3] = constantOverheadTypeEncoding(rewriter, loc,
-                                               encDst.getPointerBitWidth());
-      params[4] = constantOverheadTypeEncoding(rewriter, loc,
-                                               encDst.getIndexBitWidth());
+      params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
+      params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
       params[6] = constantAction(rewriter, loc, Action::kFromCOO);
       params[7] = coo;
       rewriter.replaceOp(op, genNewCall(rewriter, op, params));