diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -149,9 +149,11 @@ // TODO: Maybe pick the bitwidth based on input/output tensors (probably the // largest one among them) in the original operation instead of using the // default value. + unsigned pointerBitWidth = encSrc ? encSrc.getPointerBitWidth() : 0; + unsigned indexBitWidth = encSrc ? encSrc.getIndexBitWidth() : 0; auto enc = SparseTensorEncodingAttr::get( ctx, dims, AffineMap::getMultiDimIdentityMap(rank, ctx), AffineMap(), - encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); + pointerBitWidth, indexBitWidth); return RankedTensorType::get(src.getShape(), src.getElementType(), enc); } @@ -428,10 +430,23 @@ PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto rtp = op.getType().cast(); - // TODO: Build the output shape if needed. - assert(rtp.hasStaticShape()); - auto rank = rtp.getRank(); size_t conDim = op.getDimension().getZExtValue(); + SmallVector dynSizes; + if (!rtp.hasStaticShape()) { + ArrayRef rShape = rtp.getShape(); + for (const auto &d : llvm::enumerate(rShape)) { + if (d.value() == ShapedType::kDynamicSize) { + Value v = + rewriter.create(loc, op.getOperand(0), d.index()); + for (const auto &opnd : op.getOperands().drop_front()) { + Value t = rewriter.create(loc, opnd, d.index()); + v = rewriter.create(loc, v, t); + } + dynSizes.push_back(v); + } + } + } + // %t = concatenate %s1, %s2, %s3 {dim = 1} // ==> // %tmp = bufferization.alloc_tensor : unordered COO @@ -441,13 +456,11 @@ // %t = sparse_tensor.cast %tmp auto cooTp = getUnorderedCOOFromType(rtp); auto cooBuffer = - rewriter.create(loc, cooTp, ValueRange()).getResult(); - + rewriter.create(loc, cooTp, dynSizes).getResult(); + auto rank = rtp.getRank(); Value offset = constantIndex(rewriter, loc, 0); ForeachOp foreachOp; for (Value input : op.getInputs()) { - // Builds the indexing map. - // Build a for op for each input tensor to append new values into the // output tensor. foreachOp = rewriter.create( @@ -462,8 +475,16 @@ idx = builder.create(loc, idx, offset); indices.push_back(idx); } - auto t = builder.create(loc, v, reduc.front(), indices); - builder.create(loc, t); + Value cond = genIsNonzero(rewriter, loc, v); + scf::IfOp ifOp = builder.create( + loc, TypeRange(reduc.front().getType()), cond, /*else*/ true); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + Value t = builder.create(loc, v, reduc.front(), indices); + rewriter.create(loc, t); + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + rewriter.create(loc, reduc.front()); + rewriter.setInsertionPointAfter(ifOp); + rewriter.create(loc, ifOp.getResult(0)); }); // Accumulates the offset. Note that only static-shaped inputs are allowed // by concatenate op verifier, which saves us from computing the offset @@ -659,7 +680,7 @@ for (uint64_t i = 0; i < rank; i++) { uint64_t orgDim = toOrigDim(encSrc, i); xs[toStoredDim(encDst, orgDim)] = rewriter.create( - loc, indTp, src, rewriter.getIndexAttr(orgDim)); + loc, indTp, src, rewriter.getIndexAttr(i)); } // Retrieve NNZ. diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir @@ -1,4 +1,11 @@ -// RUN: mlir-opt %s --sparse-compiler | \ +// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=true | \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s +// +// Do the same run, but now with direct IR generation. +// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \ // RUN: mlir-cpu-runner \ // RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \