// Patch summary (free-text preamble placed before the first `diff --git` header).
//
// Purpose: let code holding the sparse-tensor type wrapper (`stt`) ask it directly for the
// index/pointer overhead types instead of going through `stt.getEncoding()`.
//
// Changes, per file:
//  * IR/SparseTensor.h: declares `detail::getIntegerOrIndexType(MLIRContext *ctx,
//    unsigned bitwidth)` just before the closing of namespace `sparse_tensor`.
//  * IR/SparseTensorType.h: adds `getIndexType()` and `getPointerType()` const accessors that
//    forward `getContext()` plus `getIndexBitWidth()` / `getPointerBitWidth()` to that helper.
//    The neighbouring context line returns a bit width of 0 when `enc` is null, and the
//    helper maps 0 to `IndexType`.
//    NOTE(review): the enclosing class name is outside the visible hunk — confirm.
//  * IR/SparseTensorDialect.cpp: the formerly file-local `static` helper becomes the
//    definition of `mlir::sparse_tensor::detail::getIntegerOrIndexType`. It returns
//    `IntegerType::get(ctx, bitwidth)` for a non-zero bitwidth and `IndexType::get(ctx)`
//    otherwise. The two `SparseTensorEncodingAttr` accessors now call it as `detail::...`.
//  * Transforms/SparseTensorCodegen.cpp: three call sites switch from
//    `stt.getEncoding().getPointerType()` to `stt.getPointerType()`.
//  * Transforms/SparseTensorStorageLayout.cpp: the local `enc` is removed,
//    `assert(enc)` becomes `assert(stt.hasEncoding())`, and the index/pointer types are
//    read from `stt`. `stt.getEncoding()` is still passed to `StorageSpecifierType::get`
//    and `foreachFieldInSparseTensor`.
//
// NOTE(review): this text appears to have lost its original line breaks and indentation,
// so it is probably not appliable as-is — regenerate from the commit before applying.
// NOTE(review): `llvm::function_ref callback` in the last hunk looks like it lost its
// template arguments during extraction — confirm against the upstream file.
// NOTE(review): the last hunk is cut off inside a lambda; the rest is not visible here.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -144,6 +144,10 @@ #undef DEPRECATED +namespace detail { +Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth); +} // namespace detail + } // namespace sparse_tensor } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h @@ -234,6 +234,16 @@ return enc ? enc.getPointerBitWidth() : 0; } + /// Returns the index-overhead MLIR type, defaulting to `IndexType`. + Type getIndexType() const { + return detail::getIntegerOrIndexType(getContext(), getIndexBitWidth()); + } + + /// Returns the pointer-overhead MLIR type, defaulting to `IndexType`. + Type getPointerType() const { + return detail::getIntegerOrIndexType(getContext(), getPointerBitWidth()); + } + private: // These two must be const, to ensure coherence of the memoized fields. const RankedTensorType rtp; diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -114,18 +114,19 @@ << "expect positive value or ? 
for slice offset/size/stride"; } -static Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth) { +Type mlir::sparse_tensor::detail::getIntegerOrIndexType(MLIRContext *ctx, + unsigned bitwidth) { if (bitwidth) return IntegerType::get(ctx, bitwidth); return IndexType::get(ctx); } Type SparseTensorEncodingAttr::getPointerType() const { - return getIntegerOrIndexType(getContext(), getPointerBitWidth()); + return detail::getIntegerOrIndexType(getContext(), getPointerBitWidth()); } Type SparseTensorEncodingAttr::getIndexType() const { - return getIntegerOrIndexType(getContext(), getIndexBitWidth()); + return detail::getIntegerOrIndexType(getContext(), getIndexBitWidth()); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -160,7 +160,7 @@ // Append linear x pointers, initialized to zero. Since each compressed // dimension initially already has a single zero entry, this maintains // the desired "linear + 1" length property at all times. - Type ptrType = stt.getEncoding().getPointerType(); + Type ptrType = stt.getPointerType(); Value ptrZero = constantZero(builder, loc, ptrType); createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, l, ptrZero, linear); @@ -279,8 +279,7 @@ // to all zeros, sets the dimSizes to known values and gives all pointer // fields an initial zero entry, so that it is easier to maintain the // "linear + 1" length property. - Value ptrZero = - constantZero(builder, loc, stt.getEncoding().getPointerType()); + Value ptrZero = constantZero(builder, loc, stt.getPointerType()); for (Level lvlRank = stt.getLvlRank(), l = 0; l < lvlRank; l++) { // Fills dim sizes array. 
// FIXME: this method seems to set *level* sizes, but the name is confusing @@ -546,7 +545,7 @@ // times? // if (l > 0) { - Type ptrType = stt.getEncoding().getPointerType(); + Type ptrType = stt.getPointerType(); Value ptrMemRef = desc.getPtrMemRef(l); Value hi = desc.getPtrMemSize(builder, loc, l); Value zero = constantIndex(builder, loc, 0); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp @@ -179,14 +179,13 @@ llvm::function_ref callback) { - const auto enc = stt.getEncoding(); - assert(enc); + assert(stt.hasEncoding()); // Construct the basic types. - Type idxType = enc.getIndexType(); - Type ptrType = enc.getPointerType(); + Type idxType = stt.getIndexType(); + Type ptrType = stt.getPointerType(); Type eltType = stt.getElementType(); - Type metaDataType = StorageSpecifierType::get(enc); + Type metaDataType = StorageSpecifierType::get(stt.getEncoding()); // memref pointers Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType); // memref indices @@ -195,7 +194,7 @@ Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); foreachFieldInSparseTensor( - enc, + stt.getEncoding(), [metaDataType, ptrMemType, idxMemType, valMemType, callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind, Level lvl, DimLevelType dlt) -> bool {