diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -357,6 +357,14 @@
   }
 }
 
+/// Returns true if tensor was set up with sparse storage scheme.
+static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
+  if (tensor < op.getNumInputs())
+    return isa_and_nonnull<linalg::SparseTensorFromPointerOp>(
+        op.getInput(tensor).getDefiningOp());
+  return false;
+}
+
 /// A DFS helper to compute a topological sort. Note that recursion is
 /// bounded by the number of implicit loops, which is always small.
 /// Returns false when a cycle is detected.
@@ -394,7 +402,7 @@
     auto map = op.getIndexingMap(t);
     assert(map.getNumDims() == n);
     // Skip dense tensor constraints when sparse only is requested.
-    if (sparseOnly && !merger.isSparseTensor(t))
+    if (sparseOnly && !merger.isSparseTensor(t) && !linkedSparse(op, t))
       continue;
     // At the moment, we take the index variables in the tensor access
     // expression in the order in which they appear (conceptually a
@@ -513,14 +521,6 @@
   llvm_unreachable("unexpected SparseIntType");
 }
 
-/// Returns true if tensor was set up with sparse storage scheme.
-static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
-  if (tensor < op.getNumInputs())
-    return isa_and_nonnull<linalg::SparseTensorFromPointerOp>(
-        op.getInput(tensor).getDefiningOp());
-  return false;
-}
-
 /// Generates buffer for the output tensor.
 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                              linalg::GenericOp op, MemRefType denseTp,
@@ -1004,7 +1004,7 @@
   if (needsUniv) {
     types.push_back(indexType);
     assert(codegen.loops[idx].getType().isa<IndexType>() &&
-           "type_mismatch for universal index");
+           "type mismatch for universal index");
     operands.push_back(codegen.loops[idx]);
   }
   Location loc = op.getLoc();