diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -38,6 +38,12 @@
 /// Converts the internal type-encoding for overhead storage to an mlir::Type.
 Type getOverheadType(Builder &builder, OverheadType ot);
 
+/// Returns the OverheadType encoding of the pointer bitwidth of `enc`.
+OverheadType pointerOverheadTypeEncoding(const SparseTensorEncodingAttr &enc);
+
+/// Returns the OverheadType encoding of the index bitwidth of `enc`.
+OverheadType indexOverheadTypeEncoding(const SparseTensorEncodingAttr &enc);
+
 /// Returns the mlir::Type for pointer overhead storage.
 Type getPointerOverheadType(Builder &builder,
                             const SparseTensorEncodingAttr &enc);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -58,15 +58,24 @@
   llvm_unreachable("Unknown OverheadType");
 }
 
+OverheadType mlir::sparse_tensor::pointerOverheadTypeEncoding(
+    const SparseTensorEncodingAttr &enc) {
+  return overheadTypeEncoding(enc.getPointerBitWidth());
+}
+
+OverheadType mlir::sparse_tensor::indexOverheadTypeEncoding(
+    const SparseTensorEncodingAttr &enc) {
+  return overheadTypeEncoding(enc.getIndexBitWidth());
+}
+
 Type mlir::sparse_tensor::getPointerOverheadType(
     Builder &builder, const SparseTensorEncodingAttr &enc) {
-  return getOverheadType(builder,
-                         overheadTypeEncoding(enc.getPointerBitWidth()));
+  return getOverheadType(builder, pointerOverheadTypeEncoding(enc));
 }
 
 Type mlir::sparse_tensor::getIndexOverheadType(
     Builder &builder, const SparseTensorEncodingAttr &enc) {
-  return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth()));
+  return getOverheadType(builder, indexOverheadTypeEncoding(enc));
 }
 
 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {