diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -356,10 +356,14 @@ /// whether the out tensor in the tensor expression codegen is admissable. /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective /// nesting depth when a "truly dynamic" sparse tensor output occurs. -static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op, - std::vector &topSort, unsigned exp, - OpOperand **sparseOut, - unsigned &outerParNest) { +/// Return true on accept. +/// Return false on reject. +/// Return llvm::None when a looser order should be tried. +static Optional isAdmissableTensorExp(Merger &merger, + linalg::GenericOp op, + std::vector &topSort, + unsigned exp, OpOperand **sparseOut, + unsigned &outerParNest) { OpOperand *lhs = op.getOutputOperand(0); unsigned tensor = lhs->getOperandNumber(); auto enc = getSparseTensorEncoding(lhs->get().getType()); @@ -402,6 +406,8 @@ outerParNest = nest; return true; } + // The current loop sequence is too strict. + return llvm::None; } return false; } @@ -1832,26 +1838,31 @@ // code generation can proceed. As a last resort, an attempt is made // to resolve cycles by inserting a conversion. std::vector topSort; - // Whether the current GenericOp is admissible - bool isAdmissible = false; + // Whether the current GenericOp is admissible. + Optional isAdmissible = llvm::None; // An const list of all masks that we used for interation graph // computation. Must be ordered from strict -> loose. 
const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef, SortMask::kIncludeDense, SortMask::kSparseOnly}; for (auto mask : allMask) { - if (computeIterationGraph(merger, op, topSort, mask) && - isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut, - outerParNest)) { - // This is an admissible GenericOp. - isAdmissible = true; + if (computeIterationGraph(merger, op, topSort, mask)) { + isAdmissible = isAdmissableTensorExp(merger, op, topSort, exp, + &sparseOut, outerParNest); + if (!isAdmissible) + // Due to cycle, try looser constraints. + continue; + // Either admissible or inadmissible due to other reasons. break; } // else try a less strict constraints. } - if (!isAdmissible) + if (!isAdmissible) // None // Give it one last shot to resolve the cycle. return resolveCycle(merger, rewriter, op); + if (!isAdmissible.value()) + // Inadmissible for another reason; simply reject. + return failure(); // Recursively generates code if admissible. merger.setHasSparseOut(sparseOut != nullptr);