[mlir][sparse] Assert on null encoding in getPosType/getCrdType

`SparseTensorEncodingAttr::getPosType()` and `getCrdType()` build their
result type from the attribute's own `getContext()`, so they cannot work
on a null encoding: there is nowhere else to get the `MLIRContext` from.
Add an assertion with an explicit message to both definitions and
document the non-null precondition on the declarations.

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -298,10 +298,14 @@
   ];
 
   let extraClassDeclaration = [{
-    /// Returns the type for position storage based on posWidth
+    /// Returns the type for position storage based on posWidth.
+    /// Asserts that the encoding is non-null (since there's nowhere
+    /// to get the `MLIRContext` from).
     Type getPosType() const;
 
-    /// Returns the type for coordinate storage based on crdWidth
+    /// Returns the type for coordinate storage based on crdWidth.
+    /// Asserts that the encoding is non-null (since there's nowhere
+    /// to get the `MLIRContext` from).
     Type getCrdType() const;
 
     /// Constructs a new encoding with the dimOrdering and higherOrdering
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -122,10 +122,12 @@
 }
 
 Type SparseTensorEncodingAttr::getPosType() const {
+  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
   return detail::getIntegerOrIndexType(getContext(), getPosWidth());
 }
 
 Type SparseTensorEncodingAttr::getCrdType() const {
+  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
   return detail::getIntegerOrIndexType(getContext(), getCrdWidth());
 }
 