diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -454,6 +454,7 @@ bool Merger::expContainsTensor(ExprId e, TensorId t) const { const auto &expr = exp(e); + // Base case: this is the `expIsTensor(e, t)` check, so the recursive cases need no separate child check. if (expr.kind == TensorExp::Kind::kTensor) return expr.tensor == t; @@ -462,15 +463,11 @@ return false; case ExpArity::kUnary: { const ExprId e0 = expr.children.e0; - if (expIsTensor(e0, t)) - return true; return expContainsTensor(e0, t); } case ExpArity::kBinary: { const ExprId e0 = expr.children.e0; const ExprId e1 = expr.children.e1; - if (expIsTensor(e0, t) || expIsTensor(e1, t)) - return true; return expContainsTensor(e0, t) || expContainsTensor(e1, t); } }