diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1228,18 +1228,50 @@
     concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(),
                           concatDim);
 
+    bool allDense = false;
+    Value dstTensor;
     if (encDst) {
-      // Start a new COO for the destination tensor.
-      dst =
-          params.genBuffers(encDst, sizes, dstTp).genNewCall(Action::kEmptyCOO);
-      dstPerm = params.getDim2LvlMap();
-      elemPtr = genAllocaScalar(rewriter, loc, elemTp);
+      allDense = llvm::all_of(encDst.getDimLevelType(),
+                              [](DimLevelType dlt) { return isDenseDLT(dlt); });
+      // Start a new COO or an initialized annotated all dense sparse tensor.
+      dst = params.genBuffers(encDst, sizes, dstTp)
+                .genNewCall(allDense ? Action::kEmpty : Action::kEmptyCOO);
       dstIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
+      if (allDense) {
+        dstTensor = dst;
+        // Get the values buffer for the sparse tensor and reshape it to the
+        // corresponding dense tensor shape.
+        dst = genValuesCall(rewriter, loc,
+                            MemRefType::get({ShapedType::kDynamic}, elemTp),
+                            {dst});
+
+        // Use the dstIdx to store the level sizes.
+        SmallVector<Value> lvlSizes;
+        for (unsigned i = 0; i < sizes.size(); i++)
+          lvlSizes.push_back(sizes[toOrigDim(encDst, i)]);
+        storeIndices(rewriter, loc, rank, dstIdx, lvlSizes);
+        // The memref ReshapeOp requires the sizes buffer to have a static
+        // shape.
+        Value typedBuffer = rewriter.create<memref::CastOp>(
+            loc, MemRefType::get({rank}, rewriter.getIndexType()), dstIdx);
+        SmallVector<int64_t> shape(rank, ShapedType::kDynamic);
+        dst = rewriter.create<memref::ReshapeOp>(
+            loc, MemRefType::get(shape, elemTp), dst, typedBuffer);
+      } else {
+        dstPerm = params.getDim2LvlMap();
+        elemPtr = genAllocaScalar(rewriter, loc, elemTp);
+      }
     } else {
       // TODO: Dense buffers should be allocated/deallocated via the callback
       // in BufferizationOptions.
       dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
     }
+    auto dimIdx2LvlIdx = [&](ValueRange dIdx) -> SmallVector<Value> {
+      SmallVector<Value> lIdx;
+      for (unsigned i = 0; i < dIdx.size(); i++)
+        lIdx.push_back(dIdx[toOrigDim(encDst, i)]);
+      return lIdx;
+    };
     for (auto it : llvm::zip(op.getInputs(), adaptor.getInputs())) {
       Value orignalOp = std::get<0>(it);  // Input (with encoding) from Op
       Value adaptedOp = std::get<1>(it);  // Input (type converted) from adaptor
@@ -1252,13 +1284,15 @@
                Value elemPtr) -> void {
             auto indVec =
                 loadIndices(builder, loc, rank, idx, concatDim, offset);
-            if (encDst) {
-              // Case: sparse => sparse
+            if (encDst && !allDense) {
+              // Case: sparse => sparse, except for annotated all dense.
               storeIndices(builder, loc, rank, dstIdx, indVec);
               genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstIdx,
                             dstPerm);
             } else {
-              // Case: sparse => dense
+              // Case: sparse => dense, or annotated all dense.
+              if (allDense)
+                indVec = dimIdx2LvlIdx(indVec);
               insertScalarIntoDenseTensor(builder, loc, elemPtr, dst, indVec);
             }
           });
@@ -1266,8 +1300,8 @@
         genDenseTensorIterationLoop(
             rewriter, loc, adaptedOp, srcTp,
             [&](OpBuilder &builder, Location loc, ValueRange idx) -> void {
-              if (encDst) {
-                // Case: dense => sparse
+              if (encDst && !allDense) {
+                // Case: dense => sparse, except for annotated all dense.
                storeIndices(builder, loc, rank, dstIdx, idx, concatDim,
                             offset);
                Value val = genValueForDense(builder, loc, adaptedOp, idx);
@@ -1275,12 +1309,14 @@
                genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstIdx,
                              dstPerm);
              } else {
-                // Case: dense => dense
+                // Case: dense => dense, or annotated all dense.
                Value val = genValueForDense(builder, loc, adaptedOp, idx);
                SmallVector<Value> indVec(idx);
                // Apply offset.
                indVec[concatDim] = builder.create<arith::AddIOp>(
                    loc, indVec[concatDim], offset);
+                if (allDense)
+                  indVec = dimIdx2LvlIdx(indVec);
                builder.create<memref::StoreOp>(loc, val, dst, indVec);
              }
            });
@@ -1295,11 +1331,15 @@
       offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
     }
     if (encDst) {
-      // In sparse output case, the destination holds the COO.
-      Value coo = dst;
-      dst = params.genNewCall(Action::kFromCOO, coo);
-      // Release resources.
-      genDelCOOCall(rewriter, loc, elemTp, coo);
+      if (!allDense) {
+        // In sparse output case, the destination holds the COO.
+        Value coo = dst;
+        dst = params.genNewCall(Action::kFromCOO, coo);
+        // Release resources.
+        genDelCOOCall(rewriter, loc, elemTp, coo);
+      } else {
+        dst = dstTensor;
+      }
       rewriter.replaceOp(op, dst);
     } else {
       rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstTp, dst);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -7,6 +7,11 @@
   dimOrdering = affine_map<(i,j) -> (j,i)>
 }>
 
+#SparseMatrix_D_P = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "dense" ],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
 // CHECK-LABEL: func.func @concat_mix_dense(
 // CHECK-SAME:  %[[TMP_arg0:.*]]: tensor<2x4xf64>,
 // CHECK-SAME:  %[[TMP_arg1:.*]]: !llvm.ptr<i8>)
@@ -102,9 +107,9 @@
 // CHECK-DAG:   memref.store %[[TMP_c1]], %[[Iota_0]][%[[TMP_c1]]] : memref<2xindex>
 // CHECK-DAG:   %[[NullPtr:.*]] = llvm.mlir.null : !llvm.ptr<i8>
 // CHECK:       %[[TMP_7:.*]] = call @newSparseTensor(%[[DimSizesP_0]], %[[LvlSizesP_0]], %[[LvlTypesP_0]], %[[IotaP_0]], %[[IotaP_0]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c4_i32]], %[[NullPtr]])
-// CHECK:       %[[TMP_8:.*]] = memref.alloca() : memref<f64>
 // CHECK:       %[[TMP_9:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:       %[[TMP_10:.*]] = memref.cast %[[TMP_9]] : memref<2xindex> to memref<?xindex>
+// CHECK:       %[[TMP_8:.*]] = memref.alloca() : memref<f64>
 // CHECK:       scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c2]] step %[[TMP_c1]] {
 // CHECK:         scf.for %[[TMP_arg3:.*]] = %[[TMP_c0]] to %[[TMP_c4]] step %[[TMP_c1]] {
 // CHECK:           memref.store %[[TMP_arg2]], %[[TMP_9]][%[[TMP_c0]]] : memref<2xindex>
@@ -192,9 +197,9 @@
 // CHECK-DAG:   memref.store %[[TMP_c0]], %[[Dim2Lvl_0]][%[[TMP_c1]]] : memref<2xindex>
 // CHECK-DAG:   %[[NullPtr:.*]] = llvm.mlir.null : !llvm.ptr<i8>
 // CHECK:       %[[TMP_7:.*]] = call @newSparseTensor(%[[DimSizesP_0]], %[[LvlSizesP_0]], %[[LvlTypesP_0]], %[[Lvl2DimP_0]], %[[Dim2LvlP_0]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c4_i32]], %[[NullPtr]])
-// CHECK:       %[[TMP_8:.*]] = memref.alloca() : memref<f64>
 // CHECK:       %[[TMP_9:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:       %[[TMP_10:.*]] = memref.cast %[[TMP_9]] : memref<2xindex> to memref<?xindex>
+// CHECK:       %[[TMP_8:.*]] = memref.alloca() : memref<f64>
 // CHECK:       scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c4]] step %[[TMP_c1]] {
 // CHECK:         scf.for %[[TMP_arg3:.*]] = %[[TMP_c0]] to %[[TMP_c2]] step %[[TMP_c1]] {
 // CHECK:           memref.store %[[TMP_arg2]], %[[TMP_9]][%[[TMP_c0]]] : memref<2xindex>
@@ -367,10 +372,91 @@
 // CHECK:       call @delSparseTensorIteratorF64(%[[TMP_8]]) : (!llvm.ptr<i8>) -> ()
 // CHECK:       %[[TMP_12:.*]] = bufferization.to_tensor %[[TMP_1]] : memref<?x?xf64>
 // CHECK:       return %[[TMP_12]] : tensor<?x?xf64>
-// CHECK:     }
 // CHECK:     }
 func.func @concat_mix_dense_perm_dim1_dyn(%arg0: tensor<3x2xf64>, %arg1: tensor<3x3xf64, #SparseMatrix>) -> tensor<?x?xf64> {
   %0 = sparse_tensor.concatenate %arg0, %arg1 {dimension = 1 : index}
     : tensor<3x2xf64>, tensor<3x3xf64, #SparseMatrix> to tensor<?x?xf64>
   return %0 : tensor<?x?xf64>
 }
+
+// CHECK-LABEL: func.func @concat_annotated_dense(
+// CHECK-SAME:  %[[TMP_arg0:.*]]: tensor<4x2xf64>,
+// CHECK-SAME:  %[[TMP_arg1:.*]]: !llvm.ptr<i8>)
+// CHECK-DAG:   %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
+// CHECK-DAG:   %[[TMP_c4_i8:.*]] = arith.constant 4 : i8
+// CHECK-DAG:   %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
+// CHECK-DAG:   %[[TMP_c3:.*]] = arith.constant 3 : index
+// CHECK-DAG:   %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG:   %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
+// CHECK-DAG:   %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
+// CHECK-DAG:   %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[TMP_c4:.*]] = arith.constant 4 : index
+// CHECK-DAG:   %[[TMP_c5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[LvlTypes_0:.*]] = memref.alloca() : memref<2xi8>
+// CHECK-DAG:   %[[LvlTypesP_0:.*]] = memref.cast %[[LvlTypes_0]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG:   memref.store %[[TMP_c4_i8]], %[[LvlTypes_0]][%[[TMP_c0]]] : memref<2xi8>
+// CHECK-DAG:   memref.store %[[TMP_c4_i8]], %[[LvlTypes_0]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG:   %[[DimSizes_0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG:   %[[DimSizesP_0:.*]] = memref.cast %[[DimSizes_0]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG:   memref.store %[[TMP_c4]], %[[DimSizes_0]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK-DAG:   memref.store %[[TMP_c5]], %[[DimSizes_0]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK-DAG:   %[[LvlSizes_0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG:   %[[LvlSizesP_0:.*]] = memref.cast %[[LvlSizes_0]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG:   %[[Lvl2Dim_0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG:   %[[Lvl2DimP_0:.*]] = memref.cast %[[Lvl2Dim_0]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG:   %[[Dim2Lvl_0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG:   %[[Dim2LvlP_0:.*]] = memref.cast %[[Dim2Lvl_0]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG:   memref.store %[[TMP_c1]], %[[Dim2Lvl_0]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK-DAG:   memref.store %[[TMP_c0]], %[[Dim2Lvl_0]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK-DAG:   %[[NullPtr:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+// CHECK:       %[[TMP_7:.*]] = call @newSparseTensor(%[[DimSizesP_0]], %[[LvlSizesP_0]], %[[LvlTypesP_0]], %[[Lvl2DimP_0]], %[[Dim2LvlP_0]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c0_i32]], %[[NullPtr]])
+// CHECK:       %[[Values_r:.*]] = call @sparseValuesF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> memref<?xf64>
+// CHECK:       %[[Values:.*]] = memref.reshape %[[Values_r]]
+// CHECK:       scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c4]] step %[[TMP_c1]] {
+// CHECK:         scf.for %[[TMP_arg3:.*]] = %[[TMP_c0]] to %[[TMP_c2]] step %[[TMP_c1]] {
+// CHECK:           %[[TMP_22:.*]] = tensor.extract %[[TMP_arg0]][%[[TMP_arg2]], %[[TMP_arg3]]] : tensor<4x2xf64>
+// CHECK:           %[[TMP_23:.*]] = arith.cmpf une, %[[TMP_22]], %[[TMP_cst]] : f64
+// CHECK:           scf.if %[[TMP_23]] {
+// CHECK:             memref.store %[[TMP_22]], %[[Values]][%[[TMP_arg3]], %[[TMP_arg2]]] : memref<?x?xf64>
+// CHECK:           }
+// CHECK:         }
+// CHECK:       }
+// CHECK-DAG:   %[[LvlTypes_1:.*]] = memref.alloca() : memref<2xi8>
+// CHECK-DAG:   %[[LvlTypesP_1:.*]] = memref.cast %[[LvlTypes_1]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG:   memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c0]]] : memref<2xi8>
+// CHECK-DAG:   memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG:   %[[DimSizes_1:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG:   %[[DimSizesP_1:.*]] = memref.cast %[[DimSizes_1]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG:   memref.store %[[TMP_c4]], %[[DimSizes_1]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK-DAG:   memref.store %[[TMP_c3]], %[[DimSizes_1]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK-DAG:   %[[LvlSizes_1:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG:   %[[LvlSizesP_1:.*]] = memref.cast %[[LvlSizes_1]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG:   %[[Iota_1:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG:   %[[IotaP_1:.*]] = memref.cast %[[Iota_1]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG:   memref.store %[[TMP_c0]], %[[Iota_1]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK-DAG:   memref.store %[[TMP_c1]], %[[Iota_1]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK:       %[[TMP_17:.*]] = call @newSparseTensor(%[[DimSizesP_1]], %[[LvlSizesP_1]], %[[LvlTypesP_1]], %[[IotaP_1]], %[[IotaP_1]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c6_i32]], %[[TMP_arg1]])
+// CHECK:       %[[TMP_18:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:       %[[TMP_19:.*]] = memref.cast %[[TMP_18]] : memref<2xindex> to memref<?xindex>
+// CHECK:       %[[TMP_20:.*]] = memref.alloca() : memref<f64>
+// CHECK:       scf.while : () -> () {
+// CHECK:         %[[TMP_22:.*]] = func.call @getNextF64(%[[TMP_17]], %[[TMP_19]], %[[TMP_20]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
+// CHECK:         scf.condition(%[[TMP_22]])
+// CHECK:       } do {
+// CHECK:         %[[TMP_22:.*]] = memref.load %[[TMP_18]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK:         %[[TMP_23:.*]] = memref.load %[[TMP_18]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK:         %[[TMP_24:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
+// CHECK:         %[[TMP_25:.*]] = memref.load %[[TMP_20]][] : memref<f64>
+// CHECK:         memref.store %[[TMP_25]], %[[Values]][%[[TMP_24]], %[[TMP_22]]] : memref<?x?xf64>
+// CHECK:         scf.yield
+// CHECK:       }
+// CHECK:       call @delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
+// CHECK:       return %[[TMP_7]] : !llvm.ptr<i8>
+// CHECK:     }
+func.func @concat_annotated_dense(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3xf64, #SparseMatrix_P>) -> tensor<4x5xf64, #SparseMatrix_D_P> {
+  %0 = sparse_tensor.concatenate %arg0, %arg1 {dimension = 1 : index}
+    : tensor<4x2xf64>, tensor<4x3xf64, #SparseMatrix_P> to tensor<4x5xf64, #SparseMatrix_D_P>
+  return %0 : tensor<4x5xf64, #SparseMatrix_D_P>
+}