diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -56,6 +56,7 @@
     "AffineDialect",
     "arith::ArithmeticDialect",
     "bufferization::BufferizationDialect",
+    "linalg::LinalgDialect",
    "LLVM::LLVMDialect",
    "memref::MemRefDialect",
    "scf::SCFDialect",
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -110,6 +110,24 @@
   return isZeroValue(yieldOp.getOperand(0));
 }
 
+static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
+  auto *ctx = src.getContext();
+  auto rank = src.getRank();
+  SmallVector<SparseTensorEncodingAttr::DimLevelType> dims;
+  // An unordered and non-unique compressed dim at the beginning.
+  dims.push_back(SparseTensorEncodingAttr::DimLevelType::CompressedNuNo);
+  // Followed by unordered singleton levels.
+  // TODO: it is actually ordered at this level for ordered input.
+  std::fill_n(std::back_inserter(dims), rank - 1,
+              SparseTensorEncodingAttr::DimLevelType::SingletonNo);
+  // TODO: Maybe pick the bitwidth based on input/output tensors (probably the
+  // largest one among them) in the original operation instead of using a
+  // constant value.
+  auto enc = SparseTensorEncodingAttr::get(
+      ctx, dims, AffineMap::getMultiDimIdentityMap(rank, ctx), 64, 64);
+  return RankedTensorType::get(src.getShape(), src.getElementType(), enc);
+}
+
 //===---------------------------------------------------------------------===//
 // The actual sparse tensor rewriting rules.
 //===---------------------------------------------------------------------===//
@@ -279,7 +297,8 @@
       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
       op->setOperand(0, convert);
       return success();
-    } else if (encDst) {
+    }
+    if (encDst) {
      RankedTensorType rtp =
          op.getResult().getType().template cast<RankedTensorType>();
      auto denseTp =
@@ -294,6 +313,66 @@
   }
 };
 
+struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ConcatenateOp op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Build the output shape if needed.
+    auto loc = op.getLoc();
+    auto *ctx = op.getContext();
+    auto rtp = op.getType().cast<RankedTensorType>();
+    assert(rtp.hasStaticShape());
+    auto rank = rtp.getRank();
+    size_t conDim = op.getDimension().getZExtValue();
+    // %t = concatenate %s1, %s2, %s3 {dim = 1}
+    // ==>
+    // %tmp = bufferization.alloc_tensor : unordered COO
+    // %1 = generic(ins: %s1, outs: %tmp) : (d0, d1) -> (d0, d1)
+    // %2 = generic(ins: %s2, outs: %1)   : (d0, d1) -> (d0, d1 + dim(s1))
+    // %3 = generic(ins: %s3, outs: %2)   : (d0, d1) -> (d0, d1 + dim(s1) + dim(s2))
+    // %t = sparse_tensor.convert %tmp
+    auto cooTp = getUnorderedCOOFromType(rtp);
+    auto cooBuffer =
+        rewriter.create<AllocTensorOp>(loc, cooTp, ValueRange()).getResult();
+
+    size_t offset = 0;
+    AffineMap idMap = AffineMap::getMultiDimIdentityMap(rank, ctx);
+    MutableAffineMap outMap(idMap);
+    // All parallel iterators.
+    SmallVector<StringRef> its(rank, getParallelIteratorTypeName());
+    auto itAttr = rewriter.getStrArrayAttr(its);
+    for (Value input : op.getInputs()) {
+      // d_con = d_con + offset
+      auto conExp = getAffineBinaryOpExpr(AffineExprKind::Add,
+                                          getAffineDimExpr(conDim, ctx),
+                                          getAffineConstantExpr(offset, ctx));
+      outMap.setResult(conDim, conExp);
+      // Build the indexing maps.
+      auto idxMap =
+          rewriter.getAffineMapArrayAttr({idMap, outMap.getAffineMap()});
+      // Build a generic op for each input tensor to append new values into the
+      // output tensor.
+      cooBuffer =
+          rewriter
+              .create<GenericOp>(loc, cooTp, input, cooBuffer, idxMap, itAttr,
+                                 /*doc=*/nullptr, /*library_call=*/nullptr,
+                                 [&](OpBuilder &b, Location loc, ValueRange v) {
+                                   b.create<linalg::YieldOp>(loc, v[0]);
+                                 })
+              .getResult(0);
+
+      // Accumulate the offset. Note that only static-shaped inputs are allowed
+      // by the concatenate op verifier, which saves us from computing the
+      // offset dynamically.
+      auto d = input.getType().cast<RankedTensorType>().getShape()[conDim];
+      assert(!ShapedType::isDynamic(d));
+      offset += d;
+    }
+    rewriter.replaceOpWithNewOp<ConvertOp>(op, rtp, cooBuffer);
+    return success();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -301,9 +380,11 @@
 //===---------------------------------------------------------------------===//
 
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
-                                         bool /*enableRT*/) {
+                                         bool enableRT) {
   patterns.add<FoldInvariantYield, ReshapeRewriter<tensor::ExpandShapeOp>,
                ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
-  // TODO: If RT not enabled, rewrite concatenate ops, etc here.
+  // If RT not enabled, rewrite concatenate ops, etc. here.
+  if (!enableRT)
+    patterns.add<ConcatenateRewriter>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -176,7 +176,7 @@
 /// same index is used more than once. Also rejects compound affine
 /// expressions in sparse dimensions.
 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
-                       DimLevelFormat dim) {
+                       DimLevelFormat dim, bool isOutput) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
@@ -187,14 +187,15 @@
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
-    if (dim.levelType != DimLvlType::kDense)
+    if (dim.levelType != DimLvlType::kDense && !isOutput)
       return false; // compound only in dense dim
     auto binOp = a.cast<AffineBinaryOpExpr>();
-    return findAffine(merger, tensor, binOp.getLHS(), dim) &&
-           findAffine(merger, tensor, binOp.getRHS(), dim);
+    return findAffine(merger, tensor, binOp.getLHS(), dim, isOutput) &&
+           findAffine(merger, tensor, binOp.getRHS(), dim, isOutput);
   }
   case AffineExprKind::Constant:
-    return dim.levelType == DimLvlType::kDense; // const only in dense dim
+    return dim.levelType == DimLvlType::kDense ||
+           isOutput; // const only in dense dim
   default:
     return false;
   }
@@ -216,7 +217,8 @@
   for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
     unsigned tensor = t->getOperandNumber();
     AffineExpr a = map.getResult(perm(enc, d));
-    if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
+    if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d),
+                    op.isOutputTensor(t)))
       return false; // inadmissable affine expression
   }
 }
@@ -347,10 +349,28 @@
 /// Returns true if tensor materializes uninitialized into the computation.
 static bool isMaterializing(Value val) {
+  if (auto l = val.getDefiningOp<LoadOp>())
+    return !l.getHasInserts();
   return val.getDefiningOp<linalg::InitTensorOp>() ||
          val.getDefiningOp<bufferization::AllocTensorOp>();
 }
 
+/// Returns true if the tensor construction is being finalized, i.e., this is
+/// the last GenericOp in the chain that builds up the sparse output tensor.
+static bool isFinalizing(Value val) {
+  if (!val.hasOneUse())
+    return true;
+  auto op = dyn_cast<linalg::GenericOp>(*val.getUsers().begin());
+  // We are in the middle of constructing the tensor if there is a single
+  // dataflow between multiple GenericOps.
+  // E.g.,
+  // %x = alloc_tensor
+  // %s = generic (out: %x) => not finalized
+  // %t = generic (out: %s) => finalized
+  // return %t
+  return !(op && op.getOutputs().size() == 1 && op.getOutputs()[0] == val);
+}
+
 /// Returns true when the tensor expression is admissable for codegen.
 /// Since all sparse input tensors are admissable, we just need to check
 /// whether the out tensor in the tensor expression codegen is admissable.
@@ -572,7 +592,13 @@
       codegen.indices[tensor][idx] =
           builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
     } else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) {
-      llvm_unreachable("TODO: not implemented yet");
+      // Singleton dimension, fetch indices.
+      auto dynShape = {ShapedType::kDynamicSize};
+      auto indTp =
+          MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
+      auto dim = builder.getIndexAttr(d);
+      codegen.indices[tensor][idx] =
+          builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
     }
     // Find upper bound in current dimension.
     unsigned p = perm(enc, d);
@@ -1135,6 +1161,12 @@
   if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 ||
       at != codegen.outerParNest)
     return; // not needed at this level
+  if (llvm::any_of(op.getTiedIndexingMap(lhs).getResults(),
+                   [](AffineExpr ae) { return !ae.isa<AffineDimExpr>(); })) {
+    // TODO: it could be supported.
+    assert(false && "Does not support expansion on sparse output with complex "
+                    "indexing map.");
+  }
   // Generate start or end of an expanded access pattern.
   Value tensor = lhs->get();
   Location loc = op.getLoc();
@@ -1490,9 +1522,12 @@
   // Move the insertion indices in lexicographic index order. During access
   // pattern expansion, we can skip setting the innermost dimension.
   if (codegen.sparseOut && !codegen.expValues) {
+    auto enc = getSparseTensorEncoding(codegen.sparseOut->get().getType());
+    AffineExpr a =
+        op.getTiedIndexingMap(codegen.sparseOut).getResult(perm(enc, at));
     Value pos = constantIndex(builder, loc, at);
-    builder.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
-                                    pos);
+    builder.create<memref::StoreOp>(loc, genAffine(codegen, builder, a, loc),
+                                    codegen.lexIdx, pos);
   }
 }
@@ -1785,8 +1820,11 @@
   if (getSparseTensorEncoding(resType)) {
     // The sparse tensor rematerializes from the original sparse tensor's
     // underlying sparse storage format.
+    bool isFinal = isFinalizing(op.getResult(0));
+    // We only set the hasInserts attribute when we are finalizing the sparse
+    // output tensor.
     rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
-                                        codegen.sparseOut == lhs);
+                                        codegen.sparseOut == lhs && isFinal);
   } else {
     // To rematerialize an non-annotated tensor, simply load it
     // from the bufferized value.
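
Note: the sketch below is illustrative and not part of the patch. It shows how these rewriting patterns would typically be driven from a pass, assuming a standard greedy pattern-application setup. The pass name DemoSparseRewritePass is hypothetical; populateSparseTensorRewriting and applyPatternsAndFoldGreedily are existing MLIR APIs. Passing enableRT=false selects the codegen path, so the new ConcatenateRewriter (rather than the runtime library) handles sparse_tensor.concatenate.

// Illustrative driver only; not part of this patch.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical pass that applies the sparse tensor rewriting patterns with the
// runtime path disabled, so ConcatenateRewriter rewrites concatenate ops into
// linalg.generic ops that insert into an unordered COO buffer.
struct DemoSparseRewritePass
    : public PassWrapper<DemoSparseRewritePass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DemoSparseRewritePass)

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    // enableRT = false gates in ConcatenateRewriter (see
    // populateSparseTensorRewriting above).
    populateSparseTensorRewriting(patterns, /*enableRT=*/false);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};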