diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -279,7 +279,7 @@
   // for the implicit loop indices i_0 .. i_n-1.
   unsigned n = op.getNumLoops();
   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
-
+  auto iteratorTypes = op.iterator_types().getValue();
   // Iterate over the indexing maps of every tensor in the tensor expression.
   for (OpOperand *t : op.getInputAndOutputOperands()) {
     // Skip tensor during cycle resolution.
@@ -318,8 +318,19 @@
   topSort.clear();
   topSort.reserve(n);
   std::vector<unsigned> visit(n, 0);
+  // Start the topological sort with the reduction iterators first (then the
+  // parallel iterators). This heuristic makes reduction iterators appear
+  // later in the topologically sorted result, which increases the chance
+  // that the GenericOp is admissible later in isAdmissableTensorExp().
+  // NOTE: This is still suboptimal; it does not guarantee that the first
+  // reduction iterator appears in the latest possible position.
+  for (unsigned i = 0; i < n; i++)
+    if (visit[i] == 0 && linalg::isReductionIterator(iteratorTypes[i]))
+      if (!topSortDFS(i, visit, topSort, adjM))
+        return false; // cycle!
+  // Then try the parallel iterators.
   for (unsigned i = 0; i < n; i++)
-    if (visit[i] == 0)
+    if (visit[i] == 0 && !linalg::isReductionIterator(iteratorTypes[i]))
       if (!topSortDFS(i, visit, topSort, adjM))
         return false; // cycle!
   std::reverse(std::begin(topSort), std::end(topSort));
@@ -1271,7 +1282,8 @@
   bool isParallel =
       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
 
-  assert(!merger.isDimLevelType(fb, DimLvlType::kSingleton) && "TODO: implement");
+  assert(!merger.isDimLevelType(fb, DimLvlType::kSingleton) &&
+         "TODO: implement");
 
   // Prepare vector length.
   if (isVector)
@@ -1641,8 +1653,10 @@
     unsigned lsize = merger.set(lts).size();
     for (unsigned i = 1; i < lsize; i++) {
       unsigned li = merger.set(lts)[i];
-      if (!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kCompressed) &&
-          !merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kSingleton))
+      if (!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple,
+                                       DimLvlType::kCompressed) &&
+          !merger.hasAnyDimLevelTypeOf(merger.lat(li).simple,
+                                       DimLvlType::kSingleton))
         return true;
     }
   }
@@ -1811,7 +1825,6 @@
         !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly)) {
       return resolveCycle(merger, rewriter, op);
     }
-
     // Builds the tensor expression for the Linalg operation in SSA form.
     Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
     if (!optExp.has_value())
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
@@ -24,6 +24,15 @@
     return %0 : tensor<6x6xi32>
   }
 
+  func.func @conv2d_sparse_out(%input: tensor<8x8xi32>,
+               %filter: tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR> {
+    %s = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR>
+    %0 = linalg.conv_2d
+      ins  (%input, %filter: tensor<8x8xi32>, tensor<3x3xi32, #DCSR>)
+      outs (%s: tensor<6x6xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
+    return %0 : tensor<6x6xi32, #DCSR>
+  }
+
   func.func @entry() {
     %c0 = arith.constant 0 : index
     %i0 = arith.constant 0 : i32
@@ -53,7 +62,10 @@
     %0 = call @conv2d(%input, %sparse_filter, %output)
        : (tensor<8x8xi32>,
           tensor<3x3xi32, #DCSR>, tensor<6x6xi32>) -> tensor<6x6xi32>
-
+    %1 = call @conv2d_sparse_out(%input, %sparse_filter)
+       : (tensor<8x8xi32>,
+          tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
+
     // Verify the output.
     //
     // CHECK:      ( ( 0, 0, -1, -6, -1, 6 ),
@@ -67,9 +79,24 @@
       : tensor<6x6xi32>, vector<6x6xi32>
     vector.print %v : vector<6x6xi32>
 
+    //
+    // Should be the same as the dense output.
+    // CHECK:      ( ( 0, 0, -1, -6, -1, 6 ),
+    // CHECK-SAME:   ( -1, 0, 1, 0, 1, 0 ),
+    // CHECK-SAME:   ( 0, -1, 1, 0, 0, 0 ),
+    // CHECK-SAME:   ( -1, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME:   ( 0, 0, 3, 6, -3, -6 ),
+    // CHECK-SAME:   ( 2, -1, 3, 0, -3, 0 ) )
+    //
+    %sparse_ret = sparse_tensor.convert %1
+      : tensor<6x6xi32, #DCSR> to tensor<6x6xi32>
+    %v1 = vector.transfer_read %sparse_ret[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
+    vector.print %v1 : vector<6x6xi32>
+
     // Release the resources.
     bufferization.dealloc_tensor %sparse_filter : tensor<3x3xi32, #DCSR>
-
+    bufferization.dealloc_tensor %1 : tensor<6x6xi32, #DCSR>
     return
   }
}
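
As a reading aid, here is a minimal standalone sketch of the reduction-first topological sort that the first hunk introduces. It is a simplified model, not the pass itself: `AdjMatrix`, `isReduction`, and `sortLoops` are hypothetical stand-ins for the `Merger`/`GenericOp` machinery, and only the DFS seeding order mirrors the patch.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Stand-in for the n x n iteration-graph adjacency matrix;
// adjM[f][t] == true means loop f must be placed outside loop t.
using AdjMatrix = std::vector<std::vector<bool>>;

// Post-order DFS; visit states: 0 = unvisited, 1 = on stack, 2 = done.
// Returns false when a back edge (cycle) is found.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort, const AdjMatrix &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // visit[i] == 1: i is still on the stack, a cycle.
  visit[i] = 1;
  for (unsigned j = 0, n = adjM.size(); j < n; j++)
    if (adjM[i][j] && !topSortDFS(j, visit, topSort, adjM))
      return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

// Seeds the DFS with the reduction loops first. Nodes pushed earlier in the
// post-order end up later after the final reversal, so reduction loops tend
// to land innermost (a heuristic, not a guarantee).
static bool sortLoops(const AdjMatrix &adjM,
                      const std::vector<bool> &isReduction,
                      std::vector<unsigned> &topSort) {
  unsigned n = adjM.size();
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0 && isReduction[i] &&
        !topSortDFS(i, visit, topSort, adjM))
      return false; // cycle!
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0 && !isReduction[i] &&
        !topSortDFS(i, visit, topSort, adjM))
      return false; // cycle!
  std::reverse(topSort.begin(), topSort.end());
  return true;
}

int main() {
  // Three loops: i0 and i1 are parallel, i2 is a reduction; the only
  // constraints force i0 outside both i1 and i2.
  AdjMatrix adjM(3, std::vector<bool>(3, false));
  adjM[0][1] = adjM[0][2] = true;
  std::vector<bool> isReduction = {false, false, true};
  std::vector<unsigned> topSort;
  if (sortLoops(adjM, isReduction, topSort)) {
    for (unsigned i : topSort)
      std::printf("i%u ", i); // prints "i0 i1 i2": the reduction loop is last
    std::printf("\n");
  }
  return 0;
}
```

Seeding the same graph with the parallel loops first would produce `i0 i2 i1`, leaving the reduction loop in the middle; starting the DFS from `i2` pushes it first in post-order and therefore last after the reversal, which illustrates why the patch seeds the sort with reduction iterators.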