diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -173,6 +173,12 @@ /// Constructs a new encoding with the dimOrdering and higherOrdering /// reset to the default/identity. SparseTensorEncodingAttr withoutOrdering() const; + + /// Return true if every level is dense in the encoding. + bool isAllDense() const; + + /// Return true if the encoding has an identity dimension ordering. + bool hasIdDimOrdering() const; }]; let genVerifyDecl = 1; diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -63,6 +63,14 @@ getPointerBitWidth(), getIndexBitWidth()); } +bool SparseTensorEncodingAttr::isAllDense() const { + return llvm::all_of(getDimLevelType(), isDenseDLT); +} + +bool SparseTensorEncodingAttr::hasIdDimOrdering() const { + return !getDimOrdering() || getDimOrdering().isIdentity(); +} + Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { if (failed(parser.parseLess())) return {}; @@ -172,7 +180,7 @@ } printer << " ]"; // Print remaining members only for non-default values. 
- if (getDimOrdering() && !getDimOrdering().isIdentity()) + if (!hasIdDimOrdering()) printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">"; if (getHigherOrdering()) printer << ", higherOrdering = affine_map<" << getHigherOrdering() << ">"; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -1349,8 +1349,7 @@ bool allDense = false; Value dstTensor; if (encDst) { - allDense = llvm::all_of(encDst.getDimLevelType(), - [](DimLevelType dlt) { return isDenseDLT(dlt); }); + allDense = encDst.isAllDense(); // Start a new COO or an initialized annotated all dense sparse tensor. dst = params.genBuffers(encDst, sizes, dstTp) .genNewCall(allDense ? Action::kEmpty : Action::kEmptyCOO); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -525,14 +525,35 @@ // %t = convert_to_dest_tensor(%tmp) SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); Value dst; // Destination tensor for inserting source tensor values. - bool allDense = false; + bool needTmpCOO = true; if (encDst) { - allDense = llvm::all_of(encDst.getDimLevelType(), - [](DimLevelType dlt) { return isDenseDLT(dlt); }); + bool allDense = encDst.isAllDense(); + bool allOrdered = false; + // When concatenating on dimension 0, and all inputs are sorted and have + // an identity dimOrdering, the concatenate will generate coords in + // lexOrder thus no need for the tmp COO buffer. 
+ // TODO: When conDim != 0, as long as conDim is the first dimension + // in all input/output buffers, and all input/output buffers have the same + // dimOrdering, the tmp COO buffer is still unnecessary (e.g., concatenate + // CSC matrices along column). + if (!allDense && conDim == 0 && encDst.hasIdDimOrdering()) { + for (auto i : op.getInputs()) { + auto rtp = i.getType().cast(); + auto srcEnc = getSparseTensorEncoding(rtp); + if (isAllDimOrdered(rtp) && (!srcEnc || srcEnc.hasIdDimOrdering())) { + allOrdered = true; + continue; + } + allOrdered = false; + break; + } + } + + needTmpCOO = !allDense && !allOrdered; SmallVector dynSizes; getDynamicSizes(dstTp, sizes, dynSizes); RankedTensorType tp = dstTp; - if (!allDense) { + if (needTmpCOO) { tp = getUnorderedCOOFromType(dstTp); encDst = getSparseTensorEncoding(tp); } @@ -596,7 +617,7 @@ if (encDst) { dst = rewriter.create(loc, dst, true); - if (!allDense) { + if (needTmpCOO) { Value tmpCoo = dst; dst = rewriter.create(loc, dstTp, tmpCoo).getResult(); rewriter.create(loc, tmpCoo); diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir @@ -79,8 +79,7 @@ // CHECK: scf.yield %[[RET_6]] // CHECK: } // CHECK: %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts -// CHECK: %[[TMP_22:.*]] = sparse_tensor.convert %[[TMP_23]] : tensor<9x4xf64, #sparse_tensor -// CHECK: return %[[TMP_22]] : tensor<9x4xf64, #sparse_tensor +// CHECK: return %[[TMP_23]] : tensor<9x4xf64, #sparse_tensor func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>, %arg1: tensor<3x4xf64, #DCSR>, %arg2: tensor<4x4xf64, #DCSR>) @@ -166,8 +165,7 @@ // CHECK: scf.yield %[[RET_6]] // CHECK: } // CHECK: %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts -// CHECK: %[[TMP_22:.*]] = sparse_tensor.convert %[[TMP_23]] : tensor, %arg1: 
tensor<3x4xf64, #DCSR>, %arg2: tensor<4x4xf64, #DCSR>)