diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -223,43 +223,65 @@
   return annotated;
 }
 
-/// A DFS helper to compute a topological sort. Note that recursion is
-/// bounded by the number of implicit loops, which is always small.
-/// Returns false when a cycle is detected.
-static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
-                       std::vector<unsigned> &topSort,
-                       std::vector<std::vector<bool>> &adjM) {
-  if (visit[i] != 0)
-    return visit[i] != 1; // 1 denotes cycle!
-  visit[i] = 1;
-  for (unsigned j = 0, e = visit.size(); j < e; j++)
-    if (adjM[i][j])
-      if (!topSortDFS(j, visit, topSort, adjM))
-        return false;
-  visit[i] = 2;
-  topSort.push_back(i);
-  return true;
+/// A helper to compute a topological sort. O(n^2) time complexity,
+/// since the graph is stored as an adjacency matrix.
+/// The sorted result schedules the first reduction iterator at the
+/// latest possible position.
+static bool topSortOptimal(unsigned n, ArrayRef<Attribute> iteratorTypes,
+                           std::vector<unsigned> &topSort,
+                           std::vector<unsigned> &inDegree,
+                           std::vector<std::vector<bool>> &adjM) {
+  std::vector<unsigned> redIt; // reduction iterators with in-degree 0
+  std::vector<unsigned> parIt; // parallel iterators with in-degree 0
+  for (unsigned i = 0; i < n; i++) {
+    if (inDegree[i] == 0) {
+      if (linalg::isReductionIterator(iteratorTypes[i]))
+        redIt.push_back(i);
+      else
+        parIt.push_back(i);
+    }
+  }
+
+  while (!redIt.empty() || !parIt.empty()) {
+    // Always prefer a parallel iterator if there is any.
+    auto &it = !parIt.empty() ? parIt : redIt;
+    auto src = it.back();
+    topSort.push_back(src);
+    it.pop_back();
+    // Update in-degrees, and push freed 0-degree nodes onto a worklist.
+    for (unsigned dst = 0; dst < n; dst++)
+      if (adjM[src][dst] && --inDegree[dst] == 0) {
+        if (linalg::isReductionIterator(iteratorTypes[dst]))
+          redIt.push_back(dst);
+        else
+          parIt.push_back(dst);
+      }
+  }
+  return topSort.size() == n;
 }
 
 /// Helper method to add all constraints from the indices in one affine
 /// expression before all indices in the other affine expression. For
 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
-                               AffineExpr a, AffineExpr b, unsigned fidx) {
+                               std::vector<unsigned> &inDegree, AffineExpr a,
+                               AffineExpr b, unsigned fidx) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
     if (b)
-      addAffineOrderings(adjM, b, AffineExpr(), idx);
-    else
+      addAffineOrderings(adjM, inDegree, b, AffineExpr(), idx);
+    else if (!adjM[fidx][idx]) {
       adjM[fidx][idx] = true;
+      inDegree[idx]++;
+    }
     break;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
     auto binOp = a.cast<AffineBinaryOpExpr>();
-    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
-    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
+    addAffineOrderings(adjM, inDegree, binOp.getLHS(), b, fidx);
+    addAffineOrderings(adjM, inDegree, binOp.getRHS(), b, fidx);
     break;
   }
   default:
@@ -279,7 +301,8 @@
   // for the implicit loop indices i_0 .. i_n-1.
   unsigned n = op.getNumLoops();
   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
-
+  std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
+  auto iteratorTypes = op.iterator_types().getValue();
   // Iterate over the indexing maps of every tensor in the tensor expression.
   for (OpOperand *t : op.getInputAndOutputOperands()) {
     // Skip tensor during cycle resolution.
@@ -299,7 +322,7 @@
     for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
       AffineExpr f = map.getResult(perm(enc, d - 1));
       AffineExpr t = map.getResult(perm(enc, d));
-      addAffineOrderings(adjM, f, t, 0);
+      addAffineOrderings(adjM, inDegree, f, t, 0);
     }
     // Push unrelated loops into sparse iteration space, so these
     // will be skipped more often.
@@ -309,21 +332,17 @@
         if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
            merger.isDimLevelType(tensor, i, DimLvlType::kSingleton))
          for (unsigned j = 0; j < n; j++)
-           if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef))
+           if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef)) {
              adjM[i][j] = true;
+             inDegree[j]++;
+           }
     }
   }
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
   topSort.clear();
   topSort.reserve(n);
-  std::vector<unsigned> visit(n, 0);
-  for (unsigned i = 0; i < n; i++)
-    if (visit[i] == 0)
-      if (!topSortDFS(i, visit, topSort, adjM))
-        return false; // cycle!
-  std::reverse(std::begin(topSort), std::end(topSort));
-  return true;
+  return topSortOptimal(n, iteratorTypes, topSort, inDegree, adjM);
 }
 
 /// Returns true if tensor materializes uninitialized into the computation.
@@ -1271,7 +1290,8 @@
   bool isParallel = isParallelFor(codegen, isOuter, isReduction, isSparse,
                                   isVector);
 
-  assert(!merger.isDimLevelType(fb, DimLvlType::kSingleton) && "TODO: implement");
+  assert(!merger.isDimLevelType(fb, DimLvlType::kSingleton) &&
+         "TODO: implement");
 
   // Prepare vector length.
   if (isVector)
@@ -1798,33 +1818,42 @@
     if (!findSparseAnnotations(merger, op))
       return failure();
 
+    // Builds the tensor expression for the Linalg operation in SSA form.
+    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
+    if (!optExp.has_value())
+      return failure();
+
+    unsigned exp = optExp.value();
+    OpOperand *sparseOut = nullptr;
+    unsigned outerParNest = 0;
     // Computes a topologically sorted iteration graph to ensure tensors
     // are visited in natural index order. Gradually relaxes the considered
     // constraints until an acyclic iteration graph results, such that sparse
    // code generation can proceed. As a last resort, an attempt is made
    // to resolve cycles by inserting a conversion.
     std::vector<unsigned> topSort;
-    if (!computeIterationGraph(merger, op, topSort, SortMask::kIncludeAll) &&
-        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
-        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
-        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly)) {
-      return resolveCycle(merger, rewriter, op);
-    }
+    // Whether the current GenericOp is admissible.
+    bool isAdmissible = false;
+    // A constant list of all the sort masks used for iteration graph
+    // computation, ordered from the strictest to the loosest constraint.
+    const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
+                          SortMask::kIncludeDense, SortMask::kSparseOnly};
+    for (auto mask : allMask) {
+      if (computeIterationGraph(merger, op, topSort, mask) &&
+          isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
+                                outerParNest)) {
+        // This is an admissible GenericOp.
+        isAdmissible = true;
+        break;
+      }
+      // Otherwise, try again with a less strict sort mask.
+    }
 
-    // Builds the tensor expression for the Linalg operation in SSA form.
-    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
-    if (!optExp.has_value())
-      return failure();
-    unsigned exp = optExp.value();
-
-    // Rejects an inadmissable tensor expression.
-    OpOperand *sparseOut = nullptr;
-    unsigned outerParNest = 0;
-    if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
-                               outerParNest))
-      return failure();
+    if (!isAdmissible)
+      // Give it one last shot to resolve the cycle.
+      return resolveCycle(merger, rewriter, op);
 
-    // Recursively generates code.
+    // Recursively generates code if admissible.
     merger.setHasSparseOut(sparseOut != nullptr);
     CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
     genBuffers(merger, codegen, rewriter, op);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -123,55 +123,66 @@
   return %3 : tensor<8x8xf64>
 }
 
-// CHECK-LABEL: func.func @sparse_sampled_dd_unfused(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<8x8xf64>,
-// CHECK-SAME:      %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> {
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 8 : index
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 2 : index
-// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
-// CHECK:           %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) {bufferization.escape = [false]} : tensor<8x8xf64>
-// CHECK:           %[[VAL_10:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
-// CHECK:           %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK:           %[[VAL_16:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK:           %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
-// CHECK:           %[[VAL_18:.*]] = memref.alloca(%[[VAL_6]]) : memref<?xindex>
-// CHECK:           %[[VAL_19:.*]] = memref.alloca() : memref<f64>
-// CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:           scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_5]] {
-// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_22]]] : memref<?xindex>
-// CHECK:             memref.store %[[VAL_23]], %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:             %[[VAL_24:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_22]]] : memref<?xindex>
-// CHECK:             %[[VAL_25:.*]] = arith.addi %[[VAL_22]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_26:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]]] : memref<?xindex>
-// CHECK:             scf.for %[[VAL_27:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] {
-// CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK:               memref.store %[[VAL_28]], %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:               %[[VAL_29:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]]] : memref<?xf64>
-// CHECK:               %[[VAL_30:.*]] = scf.for %[[VAL_31:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_32:.*]] = %[[VAL_7]]) -> (f64) {
-// CHECK:                 memref.store %[[VAL_31]], %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK:                 %[[VAL_33:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]], %[[VAL_31]]] : memref<8x8xf64>
-// CHECK:                 %[[VAL_34:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_31]], %[[VAL_28]]] : memref<8x8xf64>
-// CHECK:                 %[[VAL_35:.*]] = arith.mulf %[[VAL_33]], %[[VAL_34]] : f64
-// CHECK:                 %[[VAL_36:.*]] = arith.mulf %[[VAL_35]], %[[VAL_29]] : f64
-// CHECK:                 %[[VAL_37:.*]] = arith.addf %[[VAL_32]], %[[VAL_36]] : f64
-// CHECK:                 scf.yield %[[VAL_37]] : f64
-// CHECK:               }
-// CHECK:               memref.store %[[VAL_30:.*]], %[[VAL_19]][] : memref<f64>
-// CHECK:               sparse_tensor.insert %[[VAL_10]], %[[VAL_18]], %[[VAL_19]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>, memref<?xindex>, memref<f64>
-// CHECK:             }
-// CHECK:           }
-// CHECK:           %[[VAL_39:.*]] = sparse_tensor.load %[[VAL_10]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:           return %[[VAL_39]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:         }
+
+// CHECK-LABEL: func @sparse_sampled_dd_unfused(
+// CHECK-SAME:    %[[TMP_arg0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding
+// CHECK-SAME:    %[[TMP_arg1:.*]]: tensor<8x8xf64>,
+// CHECK-SAME:    %[[TMP_arg2:.*]]: tensor<8x8xf64>)
+// CHECK-DAG:     %[[TMP_c8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[TMP_false:.*]] = arith.constant false
+// CHECK-DAG:     %[[TMP_true:.*]] = arith.constant true
+// CHECK-DAG:     %[[TMP_cst:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
+// CHECK:         %[[TMP_0:.*]] = bufferization.alloc_tensor() copy(%[[TMP_cst]]) {bufferization.escape = [false]}
+// CHECK:         %[[TMP_1:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]}
+// CHECK:         %[[TMP_2:.*]] = bufferization.to_memref %[[TMP_arg1]] : memref<8x8xf64>
+// CHECK:         %[[TMP_3:.*]] = bufferization.to_memref %[[TMP_arg2]] : memref<8x8xf64>
+// CHECK:         %[[TMP_4:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 0 : index}
+// CHECK:         %[[TMP_5:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 0 : index}
+// CHECK:         %[[TMP_6:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 1 : index}
+// CHECK:         %[[TMP_7:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 1 : index}
+// CHECK:         %[[TMP_8:.*]] = sparse_tensor.values %[[TMP_arg0]]
+// CHECK:         %[[TMP_9:.*]] = memref.alloca(%[[TMP_c2]]) : memref<?xindex>
+// CHECK:         %[[TMP_10:.*]] = memref.load %[[TMP_4]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK:         %[[TMP_11:.*]] = memref.load %[[TMP_4]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK:         scf.for %[[TMP_arg3:.*]] = %[[TMP_10]] to %[[TMP_11]] step %[[TMP_c1]] {
+// CHECK:           %[[TMP_13:.*]] = memref.load %[[TMP_5]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK:           memref.store %[[TMP_13]], %[[TMP_9]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK:           %[[TMP_values:.*]], %[[TMP_filled:.*]], %[[TMP_added:.*]], %[[TMP_count:.*]] = sparse_tensor.expand %[[TMP_1]]
+// CHECK:           %[[TMP_14:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_c0]] to %[[TMP_c8]] step %[[TMP_c1]] iter_args(%[[TMP_arg5:.*]] = %[[TMP_count]]) -> (index) {
+// CHECK:             %[[TMP_15:.*]] = memref.load %[[TMP_2]][%[[TMP_13]], %[[TMP_arg4]]] : memref<8x8xf64>
+// CHECK:             %[[TMP_16:.*]] = memref.load %[[TMP_6]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK:             %[[TMP_17:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK:             %[[TMP_18:.*]] = memref.load %[[TMP_6]][%[[TMP_17]]] : memref<?xindex>
+// CHECK:             %[[TMP_19:.*]] = scf.for %[[TMP_arg6:.*]] = %[[TMP_16]] to %[[TMP_18]] step %[[TMP_c1]] iter_args(%[[TMP_arg7:.*]] = %[[TMP_arg5]]) -> (index) {
+// CHECK:               %[[TMP_20:.*]] = memref.load %[[TMP_7]][%[[TMP_arg6]]] : memref<?xindex>
+// CHECK:               %[[TMP_21:.*]] = memref.load %[[TMP_values]][%[[TMP_20]]] : memref<?xf64>
+// CHECK:               %[[TMP_22:.*]] = memref.load %[[TMP_3]][%[[TMP_arg4]], %[[TMP_20]]] : memref<8x8xf64>
+// CHECK:               %[[TMP_23:.*]] = arith.mulf %[[TMP_15]], %[[TMP_22]] : f64
+// CHECK:               %[[TMP_24:.*]] = memref.load %[[TMP_8]][%[[TMP_arg6]]] : memref<?xf64>
+// CHECK:               %[[TMP_25:.*]] = arith.mulf %[[TMP_23]], %[[TMP_24]] : f64
+// CHECK:               %[[TMP_26:.*]] = arith.addf %[[TMP_21]], %[[TMP_25]] : f64
+// CHECK:               %[[TMP_27:.*]] = memref.load %[[TMP_filled]][%[[TMP_20]]] : memref<?xi1>
+// CHECK:               %[[TMP_28:.*]] = arith.cmpi eq, %[[TMP_27]], %[[TMP_false]] : i1
+// CHECK:               %[[TMP_29:.*]] = scf.if %[[TMP_28]] -> (index) {
+// CHECK:                 memref.store %[[TMP_true]], %[[TMP_filled]][%[[TMP_20]]] : memref<?xi1>
+// CHECK:                 memref.store %[[TMP_20]], %[[TMP_added]][%[[TMP_arg7]]] : memref<?xindex>
+// CHECK:                 %[[TMP_30:.*]] = arith.addi %[[TMP_arg7]], %[[TMP_c1]] : index
+// CHECK:                 scf.yield %[[TMP_30]] : index
+// CHECK:               } else {
+// CHECK:                 scf.yield %[[TMP_arg7]] : index
+// CHECK:               }
+// CHECK:               memref.store %[[TMP_26]], %[[TMP_values]][%[[TMP_20]]] : memref<?xf64>
+// CHECK:               scf.yield %[[TMP_29]] : index
+// CHECK:             }
+// CHECK:             scf.yield %[[TMP_19]] : index
+// CHECK:           }
+// CHECK:           sparse_tensor.compress %[[TMP_1]], %[[TMP_9]], %[[TMP_values]], %[[TMP_filled]], %[[TMP_added]], %[[TMP_14]]
+// CHECK:         }
+// CHECK:         %[[TMP_12:.*]] = sparse_tensor.load %[[TMP_1]] hasInserts
+// CHECK:         return %[[TMP_12]] : tensor<8x8xf64, #sparse_tensor.encoding
 func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
                                      %arga: tensor<8x8xf64>,
                                      %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
@@ -24,6 +24,15 @@
     return %0 : tensor<6x6xi32>
   }
 
+  func.func @conv2d_sparse_out(%input: tensor<8x8xi32>,
+               %filter: tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR> {
+    %s = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR>
+    %0 = linalg.conv_2d
+      ins  (%input, %filter: tensor<8x8xi32>, tensor<3x3xi32, #DCSR>)
+      outs (%s: tensor<6x6xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
+    return %0 : tensor<6x6xi32, #DCSR>
+  }
+
   func.func @entry() {
     %c0 = arith.constant 0 : index
     %i0 = arith.constant 0 : i32
@@ -53,7 +62,10 @@
     %0 = call @conv2d(%input, %sparse_filter, %output)
        : (tensor<8x8xi32>,
          tensor<3x3xi32, #DCSR>, tensor<6x6xi32>) -> tensor<6x6xi32>
-
+    %1 = call @conv2d_sparse_out(%input, %sparse_filter)
+       : (tensor<8x8xi32>,
+          tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
+
     // Verify the output.
     //
     // CHECK: ( ( 0, 0, -1, -6, -1, 6 ),
@@ -67,9 +79,24 @@
       : tensor<6x6xi32>, vector<6x6xi32>
     vector.print %v : vector<6x6xi32>
 
+    //
+    // Should be the same as dense output
+    // CHECK: ( ( 0, 0, -1, -6, -1, 6 ),
+    // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+    // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+    // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+    // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+    //
+    %sparse_ret = sparse_tensor.convert %1
+      : tensor<6x6xi32, #DCSR> to tensor<6x6xi32>
+    %v1 = vector.transfer_read %sparse_ret[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
+    vector.print %v1 : vector<6x6xi32>
+
     // Release the resources.
     bufferization.dealloc_tensor %sparse_filter : tensor<3x3xi32, #DCSR>
-
+    bufferization.dealloc_tensor %1 : tensor<6x6xi32, #DCSR>
     return
   }
 }
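
Note on the new loop-ordering strategy: the scheduling idea in topSortOptimal can be illustrated outside MLIR. The sketch below is illustrative only; the name topSortPreferParallel and the plain std:: containers are stand-ins for the pass's Merger, SortMask, and linalg iterator-type machinery. It runs Kahn's algorithm over an adjacency matrix with two worklists, draining parallel iterators before reduction iterators so reductions are scheduled as late as the dependence constraints allow, within the O(n^2) bound noted in the patch.

// Illustrative sketch only -- not part of the patch above.
#include <iostream>
#include <vector>

// Returns a topological order of 0..n-1, or an empty vector on a cycle.
// Parallel iterators are always drained before reduction iterators, so
// reductions end up as late (innermost) as the constraints permit.
std::vector<unsigned>
topSortPreferParallel(const std::vector<std::vector<bool>> &adjM,
                      const std::vector<bool> &isReduction) {
  unsigned n = adjM.size();
  std::vector<unsigned> inDegree(n, 0), order, parIt, redIt;
  // Compute in-degrees from the adjacency matrix: O(n^2).
  for (unsigned i = 0; i < n; i++)
    for (unsigned j = 0; j < n; j++)
      if (adjM[i][j])
        inDegree[j]++;
  // Seed the two worklists with all 0-degree nodes.
  for (unsigned i = 0; i < n; i++)
    if (inDegree[i] == 0)
      (isReduction[i] ? redIt : parIt).push_back(i);
  while (!parIt.empty() || !redIt.empty()) {
    std::vector<unsigned> &work = !parIt.empty() ? parIt : redIt;
    unsigned src = work.back();
    work.pop_back();
    order.push_back(src);
    // Releasing src may free successors; route them to the right worklist.
    for (unsigned dst = 0; dst < n; dst++)
      if (adjM[src][dst] && --inDegree[dst] == 0)
        (isReduction[dst] ? redIt : parIt).push_back(dst);
  }
  return order.size() == n ? order : std::vector<unsigned>();
}

int main() {
  // Loops: 0 (parallel), 1 (parallel), 2 (reduction); one constraint 0 < 2.
  std::vector<std::vector<bool>> adjM = {{false, false, true},
                                         {false, false, false},
                                         {false, false, false}};
  std::vector<bool> isReduction = {false, false, true};
  for (unsigned idx : topSortPreferParallel(adjM, isReduction))
    std::cout << idx << ' '; // prints "1 0 2": the reduction loop goes last
  std::cout << '\n';
}

Keeping reductions innermost is what lets the relaxed admissibility loop in the Sparsification.cpp change accept kernels such as the sparse-output SDDMM exercised by the updated sparse_sddmm.mlir test.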