diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -85,16 +85,23 @@
 namespace mlir {
 namespace sparse_tensor {
 
+// NOTE: `Value::getType` doesn't check for null before trying to
+// dereference things. Therefore we check, because an assertion-failure
+// is easier to debug than a segfault. Presumably other `T::getType`
+// methods are similarly susceptible.
+
 /// Convenience method to abbreviate casting `getType()`.
 template <typename T>
-inline RankedTensorType getRankedTensorType(T t) {
-  return t.getType().template cast<RankedTensorType>();
+inline RankedTensorType getRankedTensorType(T &&t) {
+  assert(static_cast<bool>(t) && "getRankedTensorType got null argument");
+  return std::forward<T>(t).getType().template cast<RankedTensorType>();
 }
 
 /// Convenience method to abbreviate casting `getType()`.
 template <typename T>
-inline MemRefType getMemRefType(T t) {
-  return t.getType().template cast<MemRefType>();
+inline MemRefType getMemRefType(T &&t) {
+  assert(static_cast<bool>(t) && "getMemRefType got null argument");
+  return std::forward<T>(t).getType().template cast<MemRefType>();
 }
 
 /// Convenience method to get a sparse encoding attribute from a type.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -51,6 +51,9 @@
-      : rtp(rtp), enc(getSparseTensorEncoding(rtp)),
+      // Check `rtp` before any member initializer below uses it; an
+      // assertion in the constructor body would run too late.
+      : rtp((assert(rtp && "got null RankedTensorType"), rtp)),
+        enc(getSparseTensorEncoding(rtp)),
         lvlRank(enc ? enc.getLvlRank() : getDimRank()),
         dim2lvl(enc.hasIdDimOrdering() ? AffineMap() : enc.getDimOrdering()) {
     assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch");
   }
 