diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -681,14 +681,21 @@
     // COO tensor.
     // TODO: enhance foreachOp to take ordering to remove the need of a
     // temporary COO tensor here.
-    const RankedTensorType bufferTp = dstTp.isIdentity()
+    const RankedTensorType bufferTp = dstTp.isIdentity() || fromSparseConst
                                           ? dstTp.getRankedTensorType()
                                           : getUnorderedCOOFromTypeWithOrdering(
                                                 dstTp, dstTp.getDimToLvlMap());
+    // Only impose the foreach order on a dense constant (which will be
+    // statically sorted by the sparse compiler); otherwise, the rotated
+    // loop sequence results in bad cache locality.
+    AffineMapAttr foreachOrder = nullptr;
+    if (encDst.getDimOrdering() && fromSparseConst)
+      foreachOrder = AffineMapAttr::get(encDst.getDimOrdering());
+
     auto buffer =
         rewriter.create<AllocTensorOp>(loc, bufferTp, dynSizes).getResult();
     auto foreachOp = rewriter.create<ForeachOp>(
-        loc, src, buffer,
+        loc, src, buffer, foreachOrder,
         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
             ValueRange reduc) {
           Value input = reduc.front();
@@ -795,7 +802,6 @@
   // tensor (e.g., src tensor is not ordered or src tensor haves a different
   // dimOrdering).
   if (const SparseTensorType srcTp(srcRTT);
-      !isUniqueCOOType(srcRTT) &&
       !(srcTp.isAllOrdered() && srcTp.hasSameDimToLvlMap(dstTp))) {
     // Construct a COO tensor from the src tensor.
     // TODO: there may be cases for which more efficiently without
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -183,30 +183,17 @@
   return %1 : tensor<8x7xf32, #CSR>
 }
-// CHECK-RWT-LABEL: func.func @sparse_constant_csc() -> tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> {
-// CHECK-RWT: %[[F0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
-// CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[F0]] init(%[[T0]])
-// CHECK-RWT: ^bb0(%[[L0I0:.*]]: index, %[[L0I1:.*]]: index, %[[L0V:.*]]: f32, %[[L0T:.*]]: tensor
-// CHECK-RWT: %[[L0T2:.*]] = sparse_tensor.insert %[[L0V]] into %[[L0T]]{{\[}}%[[L0I1]], %[[L0I0]]]
-// CHECK-RWT: sparse_tensor.yield %[[L0T2]]
-// CHECK-RWT: }
-// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T1]] hasInserts
-// CHECK-RWT: %[[NSE:.*]] = sparse_tensor.number_of_entries %[[COO]]
-// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
-// CHECK-RWT: %[[I:.*]] = sparse_tensor.coordinates_buffer %[[COO]]
-// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[NSE]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
-// CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
-// CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor
-// CHECK-RWT: %[[L1T2:.*]] = sparse_tensor.insert %[[L1V]] into %[[L1T]]{{\[}}%[[L1I1]], %[[L1I0]]]
-// CHECK-RWT: sparse_tensor.yield %[[L1T2]]
-// CHECK-RWT: }
-// CHECK-RWT: %[[T5:.*]] = sparse_tensor.load %[[T4]] hasInserts
-// CHECK-RWT: %[[T6:.*]] = sparse_tensor.convert %[[T5]]
-// CHECK-RWT: bufferization.dealloc_tensor %[[COO]]
-// CHECK-RWT: return %[[T6]]
-// CHECK-RWT: }
+// CHECK-RWT-LABEL: func.func @sparse_constant_csc() -> tensor<8x7xf32,
+// CHECK-RWT: %[[VAL_0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
+// CHECK-RWT: %[[VAL_1:.*]] = bufferization.alloc_tensor() :
+// CHECK-RWT: %[[VAL_2:.*]] = sparse_tensor.foreach in %[[VAL_0]] init(%[[VAL_1]]) {order = #map} : tensor<8x7xf32>,
+// CHECK-RWT: ^bb0(%[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: tensor
+// CHECK-RWT: %[[VAL_7:.*]] = sparse_tensor.insert %[[VAL_5]] into %[[VAL_6]]{{\[}}%[[VAL_4]], %[[VAL_3]]] :
+// CHECK-RWT: sparse_tensor.yield %[[VAL_7]] :
+// CHECK-RWT: }
+// CHECK-RWT: %[[VAL_8:.*]] = sparse_tensor.load %[[VAL_9:.*]] hasInserts :
+// CHECK-RWT: return %[[VAL_8]] :
+// CHECK-RWT: }
 func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
   // Initialize a tensor.
   %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -153,22 +153,34 @@
 }
 // CHECK-RWT-LABEL: func.func @sparse_convert_permuted(
-// CHECK-RWT-SAME: %[[COO:.*]]:
-// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-RWT-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-RWT: %[[D0:.*]] = tensor.dim %[[COO]], %[[C0]]
-// CHECK-RWT: %[[D1:.*]] = tensor.dim %[[COO]], %[[C1]]
-// CHECK-RWT: %[[D2:.*]] = tensor.dim %[[COO]], %[[C2]]
-// CHECK-RWT: %[[T1:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]], %[[D2]])
-// CHECK-RWT: %[[T2:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T1]])
-// CHECK-RWT: ^bb0(%[[LI0:.*]]: index, %[[LI1:.*]]: index, %[[LI2:.*]]: index, %[[LV:.*]]: f32, %[[LT1:.*]]: tensor>) -> tensor> {
+// CHECK-RWT-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-RWT-DAG: %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK-RWT-DAG: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]]
+// CHECK-RWT-DAG: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]]
+// CHECK-RWT-DAG: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]]
+// CHECK-RWT-DAG: %[[VAL_7:.*]] = sparse_tensor.number_of_entries %[[VAL_0]]
+// CHECK-RWT: %[[VAL_8:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) size_hint=%[[VAL_7]]
+// CHECK-RWT: %[[VAL_9:.*]] = sparse_tensor.foreach in %[[VAL_0]] init(%[[VAL_8]])
+// CHECK-RWT: ^bb0(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index, %[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: tensor>):
+// CHECK-RWT: %[[VAL_15:.*]] = sparse_tensor.insert %[[VAL_13]] into %[[VAL_14]]{{\[}}%[[VAL_12]], %[[VAL_10]], %[[VAL_11]]]
+// CHECK-RWT: sparse_tensor.yield %[[VAL_15]] : tensor>
 // CHECK-RWT: }
-// CHECK-RWT: %[[T3:.*]] = sparse_tensor.load %[[T2:.*]] hasInserts
-// CHECK-RWT: %[[T4:.*]] = sparse_tensor.convert %[[T3]]
-// CHECK-RWT: return %[[T4]]
+// CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_17:.*]] hasInserts : tensor>
+// CHECK-RWT: %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_16]] : tensor> to memref
+// CHECK-RWT: %[[VAL_19:.*]] = sparse_tensor.coordinates_buffer %[[VAL_16]] : tensor> to memref
+// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {nx = 3 : index, ny = 0 : index}
+// CHECK-RWT: %[[VAL_20:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) size_hint=%[[VAL_7]]
+// CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.foreach in %[[VAL_16]] init(%[[VAL_20]])
+// CHECK-RWT: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index, %[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: tensor>):
+// CHECK-RWT: %[[VAL_27:.*]] = sparse_tensor.insert %[[VAL_25]] into %[[VAL_26]]{{\[}}%[[VAL_24]], %[[VAL_22]], %[[VAL_23]]]
+// CHECK-RWT: sparse_tensor.yield %[[VAL_27]]
+// CHECK-RWT: }
+// CHECK-RWT: bufferization.dealloc_tensor %[[VAL_16]]
+// CHECK-RWT: %[[VAL_28:.*]] = sparse_tensor.load %[[VAL_29:.*]] hasInserts
+// CHECK-RWT: %[[VAL_30:.*]] = sparse_tensor.convert %[[VAL_28]]
+// CHECK-RWT: return %[[VAL_30]]
 func.func @sparse_convert_permuted(%arg0: tensor) -> tensor {
   %0 = sparse_tensor.convert %arg0 : tensor to tensor
   return %0 : tensor
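
For orientation, here is a rough sketch of the IR shape the updated CHECK-RWT lines in convert_dense2sparse.mlir describe for the CSC constant case: with the `order` attribute on `sparse_tensor.foreach`, the dense constant is visited in the destination's dimension order, so the intermediate COO buffer and the `sort_coo` step disappear. The snippet is reconstructed from the test expectations only; the SSA names, the local `#CSC` alias, and the fully spelled-out types (which the test elides) are illustrative, not verbatim compiler output.

```mlir
// Hypothetical expansion of @sparse_constant_csc after this change;
// names and elided types are filled in for readability.
#CSC = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  dimOrdering = affine_map<(d0, d1) -> (d1, d0)>
}>

func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC> {
  %cst = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
  %buf = bufferization.alloc_tensor() : tensor<8x7xf32, #CSC>
  // The foreach iterates the constant in the destination's dimOrdering
  // (column-major), so inserts already arrive in sorted order.
  %filled = sparse_tensor.foreach in %cst init(%buf)
      {order = affine_map<(d0, d1) -> (d1, d0)>}
      : tensor<8x7xf32>, tensor<8x7xf32, #CSC> -> tensor<8x7xf32, #CSC> do {
    ^bb0(%d0: index, %d1: index, %v: f32, %t: tensor<8x7xf32, #CSC>):
      %ins = sparse_tensor.insert %v into %t[%d1, %d0] : tensor<8x7xf32, #CSC>
      sparse_tensor.yield %ins : tensor<8x7xf32, #CSC>
  }
  %res = sparse_tensor.load %filled hasInserts : tensor<8x7xf32, #CSC>
  return %res : tensor<8x7xf32, #CSC>
}
```

By contrast, the @sparse_convert_permuted test above keeps the two-pass shape because its source is not a constant: the rewrite still fills an unordered buffer, sorts it with `sparse_tensor.sort_coo`, and only then inserts into the final tensor.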