diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -267,6 +267,11 @@ return ldx >= numNativeLoops; } + /// Returns true if expression `e` is exactly `(kTensor t)`, i.e. its kind is `kTensor` and it refers to tensor `t`; only `e` itself is checked, sub-expressions are not inspected (that is what `expContainsTensor` does). + bool expIsTensor(unsigned e, unsigned t) const { + return tensorExps[e].kind == kTensor && tensorExps[e].tensor == t; + } + /// Returns true if the expression contains the `t` as an operand. bool expContainsTensor(unsigned e, unsigned t) const; diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -318,7 +318,7 @@ if (p0 != p1) { // Is this a straightforward copy? unsigned e = latPoints[p1].exp; - if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor) + if (expIsTensor(e, outTensor)) continue; // Conjunction already covered? for (unsigned p2 : latSets[s]) { @@ -405,15 +405,14 @@ return false; case ExpArity::kUnary: { unsigned op = tensorExps[e].children.e0; - if (tensorExps[op].kind == kTensor && tensorExps[op].tensor == t) + if (expIsTensor(op, t)) return true; return expContainsTensor(op, t); } case ExpArity::kBinary: { unsigned op1 = tensorExps[e].children.e0; unsigned op2 = tensorExps[e].children.e1; - if ((tensorExps[op1].kind == kTensor && tensorExps[op1].tensor == t) || - (tensorExps[op2].kind == kTensor && tensorExps[op2].tensor == t)) + if (expIsTensor(op1, t) || expIsTensor(op2, t)) return true; return expContainsTensor(op1, t) || expContainsTensor(op2, t); }