diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -116,6 +116,17 @@ RankedTensorType getCOOFromType(RankedTensorType src, bool ordered); +/// Returns true iff MLIR operand has any sparse operand or result. +inline bool hasAnySparseOperandOrResult(Operation *op) { + bool anySparseIn = llvm::any_of(op->getOperands().getTypes(), [](Type t) { + return getSparseTensorEncoding(t) != nullptr; + }); + bool anySparseOut = llvm::any_of(op->getResults().getTypes(), [](Type t) { + return getSparseTensorEncoding(t) != nullptr; + }); + return anySparseIn || anySparseOut; +} + // // Reordering. //