diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -115,7 +115,18 @@ static AffineMap permute(MLIRContext *context, AffineMap m,
                          std::vector<unsigned> &topSort) {
   unsigned sz = topSort.size();
+  assert(m.getNumResults() == sz && "TopoSort/AffineMap size mismatch");
+  // `m.getPermutedPosition` requires a permutation; keep that check so a
+  // non-permutation map cannot silently leave entries of `inv` unset.
+  assert(m.isPermutation() && "expected a permutation map");
+  // Construct the inverse of `m` (so `inv[m.getDimPosition(j)] == j`, i.e.
+  // `inv[d] == m.getPermutedPosition(d)`); this avoids the asymptotic
+  // complexity of calling `m.getPermutedPosition` repeatedly.
+  SmallVector<unsigned, 4> inv(sz);
+  for (unsigned i = 0; i < sz; i++)
+    inv[m.getDimPosition(i)] = i;
+  // Construct the permutation.
   SmallVector<unsigned, 4> perm(sz);
   for (unsigned i = 0; i < sz; i++)
-    perm[i] = m.getPermutedPosition(topSort[i]);
+    perm[i] = inv[topSort[i]];
   return AffineMap::getPermutationMap(perm, context);
 }