diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h @@ -190,6 +190,10 @@ Value genAddress(OpBuilder &builder, Location loc, size_t tid, size_t dim, Value iv); + /// Generates instructions to compute the coordinate of tensors[tid] on `l` + /// under the current loop context. + Value genSparseCoord(OpBuilder &builder, Location loc, size_t tid, size_t l); + bool isOutputTensor(size_t tid) { return hasOutput && tid == tensors.size() - 1; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -127,6 +127,25 @@ return add; } +Value LoopEmitter::genSparseCoord(OpBuilder &builder, Location loc, size_t tid, + size_t l) { + Value c = constantIndex(builder, loc, 0); + auto reass = getCollapseReassociation(tid, l); + for (unsigned i = 0; i < reass.size(); i++) { + auto lvl = reass[i]; + // A load on the indices array yields the coordinate. + Value ptr = idxBuffer[tid][lvl]; + Value off = genIndexLoad(builder, loc, ptr, pidxs[tid][l]); + // Linearize the coordinates within the same collapse reassociation.
+ c = builder.create<arith::AddIOp>(loc, c, off); + if (i != reass.size() - 1) { + c = builder.create<arith::MulIOp>(loc, c, + this->lvlSizes[tid][reass[i + 1]]); + } + } + return c; +} + LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput, bool isSparseOut, ArrayRef<unsigned> topSort) { initialize(tensors, loopTag, hasOutput, isSparseOut, topSort); @@ -383,20 +402,9 @@ Value c; if (isSparseInput) { assert(reass.size() == 1 || isUniqueCOOType(tensors[tid].getType())); - c = constantIndex(builder, loc, 0); - for (unsigned i = 0; i < reass.size(); i++) { - auto lvl = reass[i]; - // For COO, the pidxs are always the same across consecutive levels. - pidxs[tid][lvl] = iv; - // Generating a load on the indices array yields the coordinate. - Value ptr = idxBuffer[tid][lvl]; - Value off = genIndexLoad(builder, loc, ptr, iv); - c = builder.create<arith::AddIOp>(loc, c, off); - if (i != reass.size() - 1) { - c = builder.create<arith::MulIOp>(loc, c, - this->lvlSizes[tid][reass[i + 1]]); - } - } + // For COO, the position is the same across consecutive levels. + llvm::for_each(reass, [this, tid, iv](int lvl) { pidxs[tid][lvl] = iv; }); + c = genSparseCoord(builder, loc, tid, dim); } else { // Dense tensor, the coordinates is the inducation variable. c = iv; @@ -555,16 +563,22 @@ builder.setInsertionPointToStart(&whileOp.getBefore().front()); Value cond; unsigned o = 0; - for (auto [tid, dim] : llvm::zip(tids, dims)) { - if (isCompressedDLT(dimTypes[tid][dim]) || - isSingletonDLT(dimTypes[tid][dim])) { + for (auto [t, lvl] : llvm::zip(tids, dims)) { + auto tid = t; // C++17 lambdas cannot capture structured bindings. + if (isCompressedDLT(dimTypes[tid][lvl]) || + isSingletonDLT(dimTypes[tid][lvl])) { Value op1 = before->getArgument(o); - Value op2 = highs[tid][dim]; + Value op2 = highs[tid][lvl]; Value opc = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, op1, op2); cond = cond ? 
builder.create<arith::AndIOp>(loc, cond, opc) : opc; - // Update - pidxs[tid][dim] = after->getArgument(o++); + // Update positions + Value pos = after->getArgument(o++); + auto reass = getCollapseReassociation(tid, lvl); + assert(reass.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + // For COO, the position is the same across consecutive levels. + llvm::for_each(reass, + [this, tid, pos](int l) { pidxs[tid][l] = pos; }); } } builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); @@ -578,11 +592,10 @@ // Prepares for next level. if (isCompressedDLT(dimTypes[tid][dim]) || isSingletonDLT(dimTypes[tid][dim])) { - Value ptr = idxBuffer[tid][dim]; - Value s = pidxs[tid][dim]; - Value load = genIndexLoad(builder, loc, ptr, s); - coord[tid][dim] = load; + coord[tid][dim] = genSparseCoord(builder, loc, tid, dim); if (isSparseSlices[tid]) { + Value load = + genIndexLoad(builder, loc, idxBuffer[tid][dim], pidxs[tid][dim]); auto enc = getSparseTensorEncoding(tensors[tid].getType()); auto [trans, pred] = genSliceLegitPredicate(builder, loc, load, enc, dim); diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir @@ -23,7 +23,7 @@ func.func private @printMemref3dF32(%ptr : tensor<?x?x?xf32>) attributes { llvm.emit_c_interface } func.func private @printMemref2dF32(%ptr : tensor<?x?xf32>) attributes { llvm.emit_c_interface } - func.func @test_sparse(%arg0: tensor<5x6xf32>, %arg1: tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> { + func.func @test_sparse_rhs(%arg0: tensor<5x6xf32>, %arg1: tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> { %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<6x2x3xf32, #COO_3D> into tensor<6x6xf32, #COO_2D> %0 = tensor.empty() : tensor<5x6xf32> %cst = arith.constant 0.000000e+00 : f32 @@ -34,6 +34,17 @@ return %ret1 : 
tensor<?x?x?xf32> } + func.func
@test_sparse_all(%arg0: tensor<5x6xf32, #COO_2D>, %arg1: tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> { + %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<6x2x3xf32, #COO_3D> into tensor<6x6xf32, #COO_2D> + %0 = tensor.empty() : tensor<5x6xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32> + %2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32, #COO_2D>, tensor<6x6xf32, #COO_2D>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32> + %expanded = tensor.expand_shape %2 [[0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32> + %ret1 = tensor.cast %expanded : tensor<5x2x3xf32> to tensor<?x?x?xf32> + return %ret1 : tensor<?x?x?xf32> + } + func.func @test_dense(%arg0: tensor<5x6xf32>, %arg1: tensor<6x2x3xf32>) -> tensor<?x?x?xf32> { %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<6x2x3xf32> into tensor<6x6xf32> %0 = tensor.empty() : tensor<5x6xf32> @@ -58,6 +69,9 @@ [ 6.0, 7.0, 8.0] > : tensor<6x2x3xf32> + %s1 = sparse_tensor.convert %d1 : tensor<5x6xf32> to tensor<5x6xf32, #COO_2D> + %s2 = sparse_tensor.convert %d2 : tensor<6x2x3xf32> to tensor<6x2x3xf32, #COO_3D> + // CHECK: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [5, 2, 3] strides = [6, 3, 1] data = // CHECK-NEXT:[ // CHECK-SAME: [ @@ -96,13 +110,35 @@ // CHECK-NEXT: [ // CHECK-SAME: [0, 0, 0], // CHECK-NEXT: [0, 0, 0]]] - %s2 = sparse_tensor.convert %d2 : tensor<6x2x3xf32> to tensor<6x2x3xf32, #COO_3D> - %so1 = call @test_sparse(%d1, %s2): (tensor<5x6xf32>, tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> + %so1 = call @test_sparse_rhs(%d1, %s2): (tensor<5x6xf32>, tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> call @printMemref3dF32(%so1) : (tensor<?x?x?xf32>) -> () + // Same results.
+ // CHECK-NEXT: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [5, 2, 3] strides = [6, 3, 1] data = + // CHECK-NEXT:[ + // CHECK-SAME: [ + // CHECK-SAME: [6, 0, 0], + // CHECK-NEXT: [0, 0, 0]], + // CHECK-NEXT: [ + // CHECK-SAME: [0, 0, 0], + // CHECK-NEXT: [0, 14, 0]], + // CHECK-NEXT: [ + // CHECK-SAME: [0, 0, 0], + // CHECK-NEXT: [0, 24, 0]], + // CHECK-NEXT: [ + // CHECK-SAME: [0, 0, 0], + // CHECK-NEXT: [0, 0, 0]], + // CHECK-NEXT: [ + // CHECK-SAME: [0, 0, 0], + // CHECK-NEXT: [0, 0, 0]]] + %so2 = call @test_sparse_all(%s1, %s2): (tensor<5x6xf32, #COO_2D>, tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> + call @printMemref3dF32(%so2) : (tensor<?x?x?xf32>) -> () + + bufferization.dealloc_tensor %s1 : tensor<5x6xf32, #COO_2D> bufferization.dealloc_tensor %s2 : tensor<6x2x3xf32, #COO_3D> bufferization.dealloc_tensor %do1 : tensor<?x?x?xf32> bufferization.dealloc_tensor %so1 : tensor<?x?x?xf32> + bufferization.dealloc_tensor %so2 : tensor<?x?x?xf32> return } }