diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -407,12 +407,11 @@ auto enc = getSparseTensorEncoding(tensors[tid].getType()); const auto reassoc = getCollapseReassociation(tid, dim); - dim = reassoc.front(); // TODO: support dynamic slices. Value step = constantIndex(builder, loc, 1); - Value lo = isSparseInput ? pidxs[tid][dim] // current offset - : loopSeqStack.back(); // universal index - Value hi = highs[tid][dim]; + Value lo = isSparseInput ? pidxs[tid][reassoc.front()] // current offset + : loopSeqStack.back(); // universal index + Value hi = highs[tid][reassoc.front()]; Operation *loop = nullptr; Value iv; @@ -585,9 +584,17 @@ for (auto [tid, dim] : llvm::zip(tids, dims)) { if (isCompressedDLT(dimTypes[tid][dim]) || isSingletonDLT(dimTypes[tid][dim])) { + const auto reassoc = getCollapseReassociation(tid, dim); + for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { + if (!isUniqueDLT(dimTypes[tid][reassoc[i]])) { + // Initial segment high (0) for each non-unique level. + types.push_back(indexType); + operands.push_back(constantIndex(builder, loc, 0)); + } + } - assert(pidxs[tid][dim]); + assert(pidxs[tid][reassoc.front()]); types.push_back(indexType); - operands.push_back(pidxs[tid][dim]); + operands.push_back(pidxs[tid][reassoc.front()]); } } // The position where user-supplied reduction variable starts. @@ -616,15 +623,22 @@ unsigned tid = t; // Why `t` can not be captured by lambda? if (isCompressedDLT(dimTypes[tid][lvl]) || isSingletonDLT(dimTypes[tid][lvl])) { + const auto reassoc = getCollapseReassociation(tid, lvl); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { + if (!isUniqueDLT(dimTypes[tid][reassoc[i]])) { + // Links the SSA chain for segHi.
+ segHi[tid][reassoc[i]] = after->getArgument(o++); + } + } Value op1 = before->getArgument(o); - Value op2 = highs[tid][lvl]; + // Use the bound of the first level as the bound of the collapsed levels. + Value op2 = highs[tid][reassoc.front()]; Value opc = builder.create(loc, arith::CmpIPredicate::ult, op1, op2); cond = cond ? builder.create(loc, cond, opc) : opc; // Update positions Value pos = after->getArgument(o++); - const auto reassoc = getCollapseReassociation(tid, lvl); - assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); // For COO, the position is the same across consecutive levels. llvm::for_each(reassoc, [this, tid, pos](Level lvl) { pidxs[tid][lvl] = pos; }); @@ -714,9 +728,47 @@ assert(loopStack.size() == loopSeqStack.size()); for (auto [tid, dim] : llvm::zip(tids, dims)) { - if (!isUniqueDLT(dimTypes[tid][dim])) { - segHi[tid][dim] = genSegmentHigh(builder, loc, tid, dim, pidxs[tid][dim], - highs[tid][dim]); + const auto reassoc = getCollapseReassociation(tid, dim); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + // NOTE: For all the collapsed levels except the last one (hence the loop + // bound `reassoc.size() - 1`), each iteration advances the position by the + // segment size of the last level, which does not always invalidate the + // segment highs of the previous levels. We therefore propagate those + // segment highs across loop iterations and only recompute them if needed. + // + // E.g., for a COO tensor with the following coordinates array:
+ // (0, 0, 1), + // (0, 0, 2), + // (1, 1, 1), + // at position 0, segHi[lvl=0] = segHi[lvl=1] = 2 and segHi[lvl=2] = 1, + // so the first iteration (which advances the position to 1) does + // not invalidate segHi[0] and segHi[1]. + for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { + const auto lvl = reassoc[i]; + if (!isUniqueDLT(dimTypes[tid][lvl])) { + Value pos = pidxs[tid][lvl]; + assert(segHi[tid][lvl]); + Value newSegHi = builder.create( + loc, arith::CmpIPredicate::uge, pos, segHi[tid][lvl]); + auto ifOp = builder.create(loc, builder.getIndexType(), + newSegHi, true); + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(ifOp.thenBlock()); + builder.create( + loc, + genSegmentHigh(builder, loc, tid, lvl, pos, highs[tid][lvl])); + // Else, reuses the same segment high. + builder.setInsertionPointToStart(ifOp.elseBlock()); + builder.create(loc, segHi[tid][lvl]); + } + highs[tid][lvl + 1] = segHi[tid][lvl] = ifOp.getResult(0); + } + } + const auto lvl = reassoc.back(); + if (!isUniqueDLT(dimTypes[tid][lvl])) { + segHi[tid][lvl] = genSegmentHigh(builder, loc, tid, lvl, pidxs[tid][lvl], + highs[tid][lvl]); } } @@ -906,6 +958,15 @@ for (auto [tid, dim] : llvm::zip(tids, dims)) { if (isCompressedDLT(dimTypes[tid][dim]) || isSingletonDLT(dimTypes[tid][dim])) { + const auto reassoc = getCollapseReassociation(tid, dim); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { + const auto lvl = reassoc[i]; + if (!isUniqueDLT(dimTypes[tid][lvl])) { + operands.push_back(segHi[tid][lvl]); + o++; + } + } Value op1 = coord[tid][dim]; Value op3 = pidxs[tid][dim]; Value cmp = @@ -913,13 +974,18 @@ // If the loop contains a coiteration with non-unique level, we fast // forward all the duplicated coords by setting the position to the // segment high. - Value add = !isUniqueDLT(dimTypes[tid][dim]) - ? 
+ segHi[tid][dim] + // If this is a collapsed dim, we forward pidx based on the last level in + // the collapsed level set. + Value add = !isUniqueDLT(dimTypes[tid][reassoc.back()]) + ? segHi[tid][reassoc.back()] : builder.create(loc, op3, one); + operands.push_back(builder.create(loc, cmp, add, op3)); // Following loops continue iteration from the break point of the // current while loop. - pidxs[tid][dim] = whileOp->getResult(o++); + Value pos = whileOp->getResult(o++); + const auto t = tid; + llvm::for_each(reassoc, [this, t, pos](Level l) { pidxs[t][l] = pos; }); // The coordinates are invalid now. coord[tid][dim] = nullptr; // The segment high are invalid now diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir @@ -56,6 +56,18 @@ return %ret1 : tensor } + func.func @test_sparse_all_2(%arg0: tensor<5x6xf32, #COO_2D>, %arg1: tensor<2x3x6xf32, #COO_3D>) -> tensor { + // Collapse the first two levels this time, as these are the levels that require coiteration. + %collapsed = tensor.collapse_shape %arg1 [[0, 1], [2]] : tensor<2x3x6xf32, #COO_3D> into tensor<6x6xf32, #COO_2D> + %0 = tensor.empty() : tensor<5x6xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32> + %2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32, #COO_2D>, tensor<6x6xf32, #COO_2D>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32> + %expanded = tensor.expand_shape %2 [[0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32> + %ret1 = tensor.cast %expanded : tensor<5x2x3xf32> to tensor + return %ret1 : tensor + } + func.func @entry() { // Setup two sparse vectors.
@@ -68,9 +80,12 @@ [ [0, 0, 0], [1, 1, 1], [2, 1, 1] ], [ 6.0, 7.0, 8.0] > : tensor<6x2x3xf32> + %shape = arith.constant dense<[2, 3, 6]> : tensor<3xi32> + %d3 = tensor.reshape %d2(%shape): (tensor<6x2x3xf32>, tensor<3xi32>) -> tensor<2x3x6xf32> %s1 = sparse_tensor.convert %d1 : tensor<5x6xf32> to tensor<5x6xf32, #COO_2D> %s2 = sparse_tensor.convert %d2 : tensor<6x2x3xf32> to tensor<6x2x3xf32, #COO_3D> + %s3 = sparse_tensor.convert %d3 : tensor<2x3x6xf32> to tensor<2x3x6xf32, #COO_3D> // CHECK: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [5, 2, 3] strides = [6, 3, 1] data = // CHECK-NEXT:[ @@ -134,11 +149,34 @@ %so2 = call @test_sparse_all(%s1, %s2): (tensor<5x6xf32, #COO_2D>, tensor<6x2x3xf32, #COO_3D>) -> tensor call @printMemref3dF32(%so2) : (tensor) -> () + // Same results as @test_sparse_all (the collapsed operands hold the same 6x6 matrix). + // CHECK-NEXT: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [5, 2, 3] strides = [6, 3, 1] data = + // CHECK-NEXT:[ + // CHECK-SAME: [ + // CHECK-SAME: [6, 0, 0], + // CHECK-NEXT: [0, 0, 0]], + // CHECK-NEXT: [ + // CHECK-SAME: [0, 0, 0], + // CHECK-NEXT: [0, 14, 0]], + // CHECK-NEXT: [ + // CHECK-SAME: [0, 0, 0], + // CHECK-NEXT: [0, 24, 0]], + // CHECK-NEXT: [ + // CHECK-SAME: [0, 0, 0], + // CHECK-NEXT: [0, 0, 0]], + // CHECK-NEXT: [ + // CHECK-SAME: [0, 0, 0], + // CHECK-NEXT: [0, 0, 0]]] + %so3 = call @test_sparse_all_2(%s1, %s3): (tensor<5x6xf32, #COO_2D>, tensor<2x3x6xf32, #COO_3D>) -> tensor + call @printMemref3dF32(%so3) : (tensor) -> () + bufferization.dealloc_tensor %s1 : tensor<5x6xf32, #COO_2D> bufferization.dealloc_tensor %s2 : tensor<6x2x3xf32, #COO_3D> + bufferization.dealloc_tensor %s3 : tensor<2x3x6xf32, #COO_3D> bufferization.dealloc_tensor %do1 : tensor bufferization.dealloc_tensor %so1 : tensor bufferization.dealloc_tensor %so2 : tensor + bufferization.dealloc_tensor %so3 : tensor return } }