diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -223,43 +223,70 @@ return annotated; } -/// A DFS helper to compute a topological sort. Note that recursion is -/// bounded by the number of implicit loops, which is always small. -/// Returns false when a cycle is detected. -static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, -                       std::vector<unsigned> &topSort, -                       std::vector<std::vector<bool>> &adjM) { - if (visit[i] != 0) - return visit[i] != 1; // 1 denotes cycle! - visit[i] = 1; - for (unsigned j = 0, e = visit.size(); j < e; j++) - if (adjM[i][j]) - if (!topSortDFS(j, visit, topSort, adjM)) - return false; - visit[i] = 2; - topSort.push_back(i); - return true; +/// A helper to compute a topological sort. O(n^2) time complexity +/// as we use adj matrix for the graph. +/// The sorted result will put the first Reduction iterator to the +/// latest possible index. +static bool topSortOptimal(unsigned n, ArrayRef<Attribute> iteratorTypes, +                           std::vector<unsigned> &topSort, +                           std::vector<unsigned> &inDegree, +                           std::vector<std::vector<bool>> &adjM) { + std::vector<unsigned> redIt; // reduce iterator with 0 degree + std::vector<unsigned> parIt; // parallel iterator with 0 degree + for (unsigned i = 0; i < n; i++) { + if (inDegree[i] == 0) { + if (linalg::isReductionIterator(iteratorTypes[i])) + redIt.push_back(i); + else + parIt.push_back(i); + } + } + auto subDegree = [&](unsigned src) { + for (unsigned dst = 0; dst < n; dst++) + if (adjM[src][dst] && --inDegree[dst] == 0) { + if (linalg::isReductionIterator(iteratorTypes[dst])) + redIt.push_back(dst); + else + parIt.push_back(dst); + } + }; + + while (!redIt.empty() || !parIt.empty()) { + // We always choose parallel iterator if there is any + auto &it = !parIt.empty() ? 
parIt : redIt; + auto i = it.back(); + topSort.push_back(i); + it.pop_back(); + // Update in degree + subDegree(i); + } + return topSort.size() == n; } /// Helper method to add all constraints from the indices in one affine /// expression before all indices in the other affine expression. For /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3. -static void addAffineOrderings(std::vector<std::vector<bool>> &adjM, -                               AffineExpr a, AffineExpr b, unsigned fidx) { +static void addAffineOrderings(std::vector<std::vector<bool>> &adjM, +                               std::vector<unsigned> &inDegree, AffineExpr a, +                               AffineExpr b, unsigned fidx) { switch (a.getKind()) { case AffineExprKind::DimId: { unsigned idx = a.cast<AffineDimExpr>().getPosition(); if (b) - addAffineOrderings(adjM, b, AffineExpr(), idx); - else - adjM[fidx][idx] = true; + addAffineOrderings(adjM, inDegree, b, AffineExpr(), idx); + else { + if (!adjM[fidx][idx]) { + adjM[fidx][idx] = true; + inDegree[idx]++; + } + } break; } case AffineExprKind::Add: case AffineExprKind::Mul: { auto binOp = a.cast<AffineBinaryOpExpr>(); - addAffineOrderings(adjM, binOp.getLHS(), b, fidx); - addAffineOrderings(adjM, binOp.getRHS(), b, fidx); + addAffineOrderings(adjM, inDegree, binOp.getLHS(), b, fidx); + addAffineOrderings(adjM, inDegree, binOp.getRHS(), b, fidx); break; } default: @@ -279,7 +306,8 @@ // for the implicit loop indices i_0 .. i_n-1. unsigned n = op.getNumLoops(); std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); - + std::vector<unsigned> inDegree(n, 0); // in degree of each node. + auto iteratorTypes = op.iterator_types().getValue(); // Iterate over the indexing maps of every tensor in the tensor expression. for (OpOperand *t : op.getInputAndOutputOperands()) { // Skip tensor during cycle resolution. @@ -299,7 +327,7 @@ for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) { AffineExpr f = map.getResult(perm(enc, d - 1)); AffineExpr t = map.getResult(perm(enc, d)); - addAffineOrderings(adjM, f, t, 0); + addAffineOrderings(adjM, inDegree, f, t, 0); } // Push unrelated loops into sparse iteration space, so these // will be skipped more often. 
@@ -309,21 +337,17 @@ if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) || merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) for (unsigned j = 0; j < n; j++) - if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef)) + if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef)) { adjM[i][j] = true; + inDegree[j]++; + } } } // Topologically sort the iteration graph to determine loop order. // Report failure for a cyclic iteration graph. topSort.clear(); topSort.reserve(n); - std::vector<unsigned> visit(n, 0); - for (unsigned i = 0; i < n; i++) - if (visit[i] == 0) - if (!topSortDFS(i, visit, topSort, adjM)) - return false; // cycle! - std::reverse(std::begin(topSort), std::end(topSort)); - return true; + return topSortOptimal(n, iteratorTypes, topSort, inDegree, adjM); } /// Returns true if tensor materializes uninitialized into the computation. @@ -1271,7 +1295,8 @@ bool isParallel = isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); - assert(!merger.isDimLevelType(fb, DimLvlType::kSingleton) && "TODO: implement"); + assert(!merger.isDimLevelType(fb, DimLvlType::kSingleton) && + "TODO: implement"); // Prepare vector length. if (isVector) @@ -1641,8 +1666,10 @@ unsigned lsize = merger.set(lts).size(); for (unsigned i = 1; i < lsize; i++) { unsigned li = merger.set(lts)[i]; - if (!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kCompressed) && - !merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kSingleton)) + if (!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, + DimLvlType::kCompressed) && + !merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, + DimLvlType::kSingleton)) return true; } } @@ -1811,7 +1838,6 @@ !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly)) { return resolveCycle(merger, rewriter, op); } - // Builds the tensor expression for the Linalg operation in SSA form. 
Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op); if (!optExp.has_value()) diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir @@ -24,6 +24,15 @@ return %0 : tensor<6x6xi32> } + func.func @conv2d_sparse_out(%input: tensor<8x8xi32>, + %filter: tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR> { + %s = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR> + %0 = linalg.conv_2d + ins (%input, %filter: tensor<8x8xi32>, tensor<3x3xi32, #DCSR>) + outs (%s: tensor<6x6xi32, #DCSR>) -> tensor<6x6xi32, #DCSR> + return %0 : tensor<6x6xi32, #DCSR> + } + func.func @entry() { %c0 = arith.constant 0 : index %i0 = arith.constant 0 : i32 @@ -53,7 +62,10 @@ %0 = call @conv2d(%input, %sparse_filter, %output) : (tensor<8x8xi32>, tensor<3x3xi32, #DCSR>, tensor<6x6xi32>) -> tensor<6x6xi32> - + %1 = call @conv2d_sparse_out(%input, %sparse_filter) + : (tensor<8x8xi32>, + tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR> + // Verify the output. // // CHECK: ( ( 0, 0, -1, -6, -1, 6 ), @@ -67,9 +79,24 @@ : tensor<6x6xi32>, vector<6x6xi32> vector.print %v : vector<6x6xi32> + // + // Should be the same as dense output + // CHECK: ( ( 0, 0, -1, -6, -1, 6 ), + // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ), + // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ), + // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ), + // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ), + // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) ) + // + %sparse_ret = sparse_tensor.convert %1 + : tensor<6x6xi32, #DCSR> to tensor<6x6xi32> + %v1 = vector.transfer_read %sparse_ret[%c0, %c0], %i0 + : tensor<6x6xi32>, vector<6x6xi32> + vector.print %v1 : vector<6x6xi32> + // Release the resources. 
bufferization.dealloc_tensor %sparse_filter : tensor<3x3xi32, #DCSR> - + bufferization.dealloc_tensor %1 : tensor<6x6xi32, #DCSR> return } }