diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1832,26 +1832,30 @@ // code generation can proceed. As a last resort, an attempt is made // to resolve cycles by inserting a conversion. std::vector topSort; - // Whether the current GenericOp is admissible + // Whether the current GenericOp is admissible. bool isAdmissible = false; + bool hasCycle = true; // An const list of all masks that we used for interation graph // computation. Must be ordered from strict -> loose. const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef, SortMask::kIncludeDense, SortMask::kSparseOnly}; - for (auto mask : allMask) { - if (computeIterationGraph(merger, op, topSort, mask) && - isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut, - outerParNest)) { - // This is an admissible GenericOp. - isAdmissible = true; - break; + for (auto mask : allMask) + if (computeIterationGraph(merger, op, topSort, mask)) { + hasCycle = false; + if (isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut, + outerParNest)) { + isAdmissible = true; + break; + } + // else try a set of less strict constraints. } - // else try a less strict constraints. - } - if (!isAdmissible) + if (hasCycle) // Give it one last shot to resolve the cycle. return resolveCycle(merger, rewriter, op); + if (!isAdmissible) + // Inadmissible expression, reject. + return failure(); // Recursively generates code if admissible. merger.setHasSparseOut(sparseOut != nullptr);