diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -31,6 +31,11 @@
 // Helper methods for the actual rewriting rules.
 //===---------------------------------------------------------------------===//
 
+// Helper method to match any typed zero (integer or floating-point constant).
+static bool isZeroValue(Value val) {
+  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
+}
+
 // Helper to detect a sparse tensor type operand.
 static bool isSparseTensor(OpOperand *op) {
   if (auto enc = getSparseTensorEncoding(op->get().getType())) {
@@ -47,8 +52,7 @@
   if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
     Value copy = alloc.getCopy();
     if (isZero)
-      return copy && (matchPattern(copy, m_Zero()) ||
-                      matchPattern(copy, m_AnyZeroFloat()));
+      return copy && isZeroValue(copy);
     return !copy;
   }
   return false;
@@ -100,13 +104,10 @@
   if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
     if (arg.getOwner()->getParentOp() == op) {
       OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
-      return matchPattern(t->get(), m_Zero()) ||
-             matchPattern(t->get(), m_AnyZeroFloat());
+      return isZeroValue(t->get());
     }
-  } else if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
-    return matchPattern(def, m_Zero()) || matchPattern(def, m_AnyZeroFloat());
   }
-  return false;
+  return isZeroValue(yieldOp.getOperand(0));
 }
 
 //===---------------------------------------------------------------------===//