diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -620,7 +620,6 @@ Value src = op.getSource(); RankedTensorType srcTp = src.getType().cast(); RankedTensorType dstTp = op.getType().cast(); - SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp); SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); SmallVector srcSizes; @@ -640,17 +639,15 @@ loc, src, tmpCoo, [&](OpBuilder &builder, Location loc, ValueRange args, Value v, ValueRange reduc) { - SmallVector indices; - for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) { - uint64_t dim = toStoredDim(encSrc, i); - indices.push_back(args[dim]); - } - auto t = builder.create(loc, v, reduc.front(), indices); + // The resulting COO tensor has identity ordering. + auto t = builder.create(loc, v, reduc.front(), + args.slice(0, srcTp.getRank())); builder.create(loc, t); }); src = rewriter.create(loc, foreachOp.getResult(0), true); } + SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp); // Sort the COO tensor so that its elements are ordered via increasing // indices for the storage ordering of the dst tensor. 
auto dynShape = {ShapedType::kDynamicSize}; @@ -682,14 +679,14 @@ getDynamicSizes(dstTp, srcSizes, dynDstSizes); Value dst = rewriter.create(loc, dstTp, dynDstSizes).getResult(); + SmallVector indices(srcTp.getRank(), Value()); auto foreachOp = rewriter.create( loc, src, dst, [&](OpBuilder &builder, Location loc, ValueRange args, Value v, ValueRange reduc) { - SmallVector indices; for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) { uint64_t dim = toStoredDim(encDst, i); - indices.push_back(args[dim]); + indices[dim] = args[i]; } auto t = builder.create(loc, v, reduc.front(), indices); builder.create(loc, t); diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir --- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir +++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir @@ -25,6 +25,16 @@ dimLevelType = ["compressed"] }> +#SortedWRT3D = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed-nu", "singleton-nu", "singleton" ] + +}> + +#TsssPermuted = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed", "compressed" ], + dimOrdering = affine_map<(i,j,k) -> (k,i,j)> +}> + // CHECK-LABEL: func @sparse_nop_convert( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) -> !llvm.ptr // CHECK: return %[[A]] : !llvm.ptr @@ -146,3 +156,31 @@ %0 = sparse_tensor.convert %arg0 : tensor to tensor return %0 : tensor } + +// CHECK-WRT-LABEL: func.func @sparse_convert_permuted( +// CHECK-WRT-SAME: %[[COO:.*]]: +// CHECK-WRT-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-WRT-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-WRT-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-WRT: %[[D0:.*]] = tensor.dim %[[COO]], %[[C0]] +// CHECK-WRT: %[[D1:.*]] = tensor.dim %[[COO]], %[[C1]] +// CHECK-WRT: %[[D2:.*]] = tensor.dim %[[COO]], %[[C2]] +// CHECK-WRT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index} +// CHECK-WRT: %[[I1:.*]] = sparse_tensor.indices 
%[[COO]] {dimension = 1 : index} +// CHECK-WRT: %[[I2:.*]] = sparse_tensor.indices %[[COO]] {dimension = 2 : index} +// CHECK-WRT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]] +// CHECK-WRT: %[[V:.*]] = sparse_tensor.values %[[COO]] +// CHECK-WRT: sparse_tensor.sort %[[NNZ]], %[[I2]], %[[I0]], %[[I1]] jointly %[[V]] +// CHECK-WRT: %[[T1:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]], %[[D2]]) +// CHECK-WRT: %[[T2:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T1]]) +// CHECK-WRT: ^bb0(%[[LI0:.*]]: index, %[[LI1:.*]]: index, %[[LI2:.*]]: index, %[[LV:.*]]: f32, %[[LT1:.*]]: tensor) -> tensor { +func.func @sparse_convert_permuted(%arg0: tensor) -> tensor { + %0 = sparse_tensor.convert %arg0 : tensor to tensor + return %0 : tensor +} diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sorted_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sorted_coo.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sorted_coo.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sorted_coo.mlir @@ -1,4 +1,12 @@ -// RUN: mlir-opt %s --sparse-compiler | \ +// RUN: mlir-opt %s --sparse-compiler="enable-runtime-library=true" | \ +// RUN: TENSOR0="%mlir_src_dir/test/Integration/data/wide.mtx" \ +// RUN: TENSOR1="%mlir_src_dir/test/Integration/data/mttkrp_b.tns" \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// RUN: mlir-opt %s --sparse-compiler="enable-runtime-library=false enable-buffer-initialization=true" | \ // RUN: TENSOR0="%mlir_src_dir/test/Integration/data/wide.mtx" \ // RUN: TENSOR1="%mlir_src_dir/test/Integration/data/mttkrp_b.tns" \ // RUN: mlir-cpu-runner \ // RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s @@ -66,7 +74,7 @@ func.func @dumpf(%arg0: memref) { %c0 = arith.constant 0 : index - %nan = arith.constant 0x7FF0000001000000 : f64 + %nan = arith.constant 0x0 : f64 %v = vector.transfer_read %arg0[%c0], %nan: memref, vector<20xf64> vector.print %v : vector<20xf64> return @@ -96,7 +104,7 @@ 
// CHECK: ( 0, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK-NEXT: ( 0, 0, 0, 0, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0 ) // CHECK-NEXT: ( 0, 126, 127, 254, 1, 253, 2, 0, 1, 3, 98, 126, 127, 128, 249, 253, 255, 0, 0, 0 ) - // CHECK-NEXT: ( -1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, 12, -13, 14, -15, 16, -17, nan, nan, nan ) + // CHECK-NEXT: ( -1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, 12, -13, 14, -15, 16, -17, 0, 0, 0 ) // %p0 = sparse_tensor.pointers %0 { dimension = 0 : index } : tensor to memref @@ -115,7 +123,7 @@ // CHECK-NEXT: ( 0, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK-NEXT: ( 0, 0, 1, 1, 2, 3, 98, 126, 126, 127, 127, 128, 249, 253, 253, 254, 255, 0, 0, 0 ) // CHECK-NEXT: ( 0, 3, 1, 3, 2, 3, 3, 0, 3, 0, 3, 3, 3, 1, 3, 0, 3, 0, 0, 0 ) - // CHECK-NEXT: ( -1, 8, -5, -9, -7, 10, -11, 2, 12, -3, -13, 14, -15, 6, 16, 4, -17, nan, nan, nan ) + // CHECK-NEXT: ( -1, 8, -5, -9, -7, 10, -11, 2, 12, -3, -13, 14, -15, 6, 16, 4, -17, 0, 0, 0 ) // %p1 = sparse_tensor.pointers %1 { dimension = 0 : index } : tensor to memref @@ -134,8 +142,8 @@ // CHECK-NEXT: ( 0, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0 ) // CHECK-NEXT: ( 0, 0, 1, 1, 2, 2, 2, 2, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 0 ) - // CHECK-NEXT: ( 2, 3, 1, 2, 0, 1, 2, 3, 0, 2, 3, 0, 1, 2, 3, 1, 2, 0, 0, 0 ) - // CHECK-NEXT: ( 3, 63, 11, 100, 66, 61, 13, 43, 77, 10, 46, 61, 53, 3, 75, 22, 18, nan, nan, nan ) + // CHECK-NEXT: ( 2, 3, 1, 2, 0, 1, 2, 3, 0, 2, 3, 0, 1, 2, 3, 1, 2, 0, 0, 0 ) + // CHECK-NEXT: ( 3, 63, 11, 100, 66, 61, 13, 43, 77, 10, 46, 61, 53, 3, 75, 22, 18, 0, 0, 0 ) // %p2 = sparse_tensor.pointers %2 { dimension = 0 : index } : tensor to memref @@ -150,15 +158,15 @@ call @dumpi(%p2) : (memref) -> () call @dumpi(%i20) : (memref) -> () call @dumpi(%i21) : (memref) -> () - call @dumpi(%i22) : (memref) -> () + call @dumpi(%i22) : (memref) -> () call @dumpf(%v2) : 
(memref) -> () // // CHECK-NEXT: ( 0, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK-NEXT: ( 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0 ) // CHECK-NEXT: ( 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0 ) - // CHECK-NEXT: ( 2, 0, 1, 1, 2, 1, 2, 0, 1, 2, 0, 1, 2, 0, 2, 0, 1, 0, 0, 0 ) - // CHECK-NEXT: ( 66, 77, 61, 11, 61, 53, 22, 3, 100, 13, 10, 3, 18, 63, 43, 46, 75, nan, nan, nan ) + // CHECK-NEXT: ( 2, 0, 1, 1, 2, 1, 2, 0, 1, 2, 0, 1, 2, 0, 2, 0, 1, 0, 0, 0 ) + // CHECK-NEXT: ( 66, 77, 61, 11, 61, 53, 22, 3, 100, 13, 10, 3, 18, 63, 43, 46, 75, 0, 0, 0 ) // %p3 = sparse_tensor.pointers %3 { dimension = 0 : index } : tensor to memref @@ -173,14 +181,14 @@ call @dumpi(%p3) : (memref) -> () call @dumpi(%i30) : (memref) -> () call @dumpi(%i31) : (memref) -> () - call @dumpi(%i32) : (memref) -> () + call @dumpi(%i32) : (memref) -> () call @dumpf(%v3) : (memref) -> () // // CHECK-NEXT: ( 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK-NEXT: ( 0, 1, 2, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK-NEXT: ( 0, 3, 0, 3, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) - // CHECK-NEXT: ( 6, 5, 4, 3, 2, 11, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan ) + // CHECK-NEXT: ( 6, 5, 4, 3, 2, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // %p4 = sparse_tensor.pointers %4 { dimension = 0 : index } : tensor to memref @@ -203,7 +211,7 @@ // CHECK-NEXT: ( 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK-NEXT: ( 0, 1, 2, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK-NEXT: ( 0, 3, 0, 3, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) - // CHECK-NEXT: ( 12, 10, 8, 6, 4, 22, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan ) + // CHECK-NEXT: ( 12, 10, 8, 6, 4, 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // %p5 = sparse_tensor.pointers %5 { dimension = 0 : index } : tensor to memref