diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -56,6 +56,7 @@
     "AffineDialect",
     "arith::ArithmeticDialect",
     "bufferization::BufferizationDialect",
+    "linalg::LinalgDialect",
    "LLVM::LLVMDialect",
     "memref::MemRefDialect",
     "scf::SCFDialect",
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -110,6 +110,26 @@
   return isZeroValue(yieldOp.getOperand(0));
 }
 
+static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
+  auto *ctx = src.getContext();
+  auto rank = src.getRank();
+  SmallVector<SparseTensorEncodingAttr::DimLevelType> dims;
+  // An unordered and non-unique compressed dim at the beginning.
+  dims.push_back(SparseTensorEncodingAttr::DimLevelType::CompressedNuNo);
+  // TODO: this level is actually ordered for ordered inputs.
+  // Followed by n-2 unordered non-unique singleton levels.
+  std::fill_n(std::back_inserter(dims), rank - 2,
+              SparseTensorEncodingAttr::DimLevelType::SingletonNuNo);
+  // Ends with an unordered unique singleton level.
+  dims.push_back(SparseTensorEncodingAttr::DimLevelType::SingletonNo);
+  // TODO: Maybe pick the bitwidth based on the input/output tensors (probably
+  // the largest one among them) in the original operation instead of using
+  // the default value.
+  auto enc = SparseTensorEncodingAttr::get(
+      ctx, dims, AffineMap::getMultiDimIdentityMap(rank, ctx), 0, 0);
+  return RankedTensorType::get(src.getShape(), src.getElementType(), enc);
+}
+
 //===---------------------------------------------------------------------===//
 // The actual sparse tensor rewriting rules.
 //===---------------------------------------------------------------------===//
@@ -279,7 +299,8 @@
       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
       op->setOperand(0, convert);
       return success();
-    } else if (encDst) {
+    }
+    if (encDst) {
       RankedTensorType rtp =
           op.getResult().getType().template cast<RankedTensorType>();
       auto denseTp =
@@ -294,6 +315,66 @@
   }
 };
 
+struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
+  using OpRewritePattern<ConcatenateOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ConcatenateOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto *ctx = op.getContext();
+    auto rtp = op.getType().cast<RankedTensorType>();
+    // TODO: Build the output shape if needed.
+    assert(rtp.hasStaticShape());
+    auto rank = rtp.getRank();
+    size_t conDim = op.getDimension().getZExtValue();
+    // %t = concatenate %s1, %s2, %s3 {dim = 1}
+    // ==>
+    // %tmp = bufferization.alloc_tensor : unordered COO
+    // %1 = generic(ins: %s1, outs: %tmp) : (d0, d1) -> (d0, d1)
+    // %2 = generic(ins: %s2, outs: %1) : (d0, d1) -> (d0, d1 + dim(s1))
+    // %3 = generic(ins: %s3, outs: %2) : (d0, d1) -> (d0, d1 + dim(s1) + dim(s2))
+    // %t = sparse_tensor.convert %tmp
+    auto cooTp = getUnorderedCOOFromType(rtp);
+    auto cooBuffer =
+        rewriter.create<bufferization::AllocTensorOp>(loc, cooTp, ValueRange())
+            .getResult();
+
+    size_t offset = 0;
+    AffineMap idMap = AffineMap::getMultiDimIdentityMap(rank, ctx);
+    MutableAffineMap outMap(idMap);
+    // All parallel iterators.
+    SmallVector<StringRef> its(rank, getParallelIteratorTypeName());
+    auto itAttr = rewriter.getStrArrayAttr(its);
+    for (Value input : op.getInputs()) {
+      // d_con = d_con + offset
+      auto conExp = getAffineBinaryOpExpr(AffineExprKind::Add,
+                                          getAffineDimExpr(conDim, ctx),
+                                          getAffineConstantExpr(offset, ctx));
+      outMap.setResult(conDim, conExp);
+      // Build the indexing maps.
+      auto idxMap =
+          rewriter.getAffineMapArrayAttr({idMap, outMap.getAffineMap()});
+      // Build a generic op for each input tensor to append new values into
+      // the output tensor.
+      cooBuffer =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, cooTp, input, cooBuffer, idxMap, itAttr,
+                  /*doc=*/nullptr, /*library_call=*/nullptr,
+                  [&](OpBuilder &b, Location loc, ValueRange v) {
+                    b.create<linalg::YieldOp>(loc, v[0]);
+                  })
+              .getResult(0);
+
+      // Accumulate the offset. Note that only static-shaped inputs are
+      // allowed by the concatenate op verifier, which saves us from computing
+      // the offset dynamically.
+      auto d = input.getType().cast<RankedTensorType>().getShape()[conDim];
+      assert(!ShapedType::isDynamic(d));
+      offset += d;
+    }
+    rewriter.replaceOpWithNewOp<ConvertOp>(op, rtp, cooBuffer);
+    return success();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -301,9 +382,11 @@
 //===---------------------------------------------------------------------===//
 
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
-                                         bool /*enableRT*/) {
+                                         bool enableRT) {
   patterns.add<FoldInvariantYield, ReshapeRewriter<tensor::ExpandShapeOp>,
                ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
-  // TODO: If RT not enabled, rewrite concatenate ops, etc here.
+  // If the runtime library is not enabled, rewrite concatenate ops, etc. here.
+  if (!enableRT)
+    patterns.add<ConcatenateRewriter>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -176,7 +176,7 @@
 /// same index is used more than once. Also rejects compound affine
 /// expressions in sparse dimensions.
 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
-                       DimLevelFormat dim) {
+                       DimLevelFormat dim, bool isOutput) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
@@ -187,14 +187,15 @@
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
-    if (dim.levelType != DimLvlType::kDense)
+    if (dim.levelType != DimLvlType::kDense && !isOutput)
       return false; // compound only in dense dim
     auto binOp = a.cast<AffineBinaryOpExpr>();
-    return findAffine(merger, tensor, binOp.getLHS(), dim) &&
-           findAffine(merger, tensor, binOp.getRHS(), dim);
+    return findAffine(merger, tensor, binOp.getLHS(), dim, isOutput) &&
+           findAffine(merger, tensor, binOp.getRHS(), dim, isOutput);
   }
   case AffineExprKind::Constant:
-    return dim.levelType == DimLvlType::kDense; // const only in dense dim
+    return dim.levelType == DimLvlType::kDense ||
+           isOutput; // const only in dense dim
   default:
     return false;
   }
@@ -216,7 +217,8 @@
   for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
     unsigned tensor = t->getOperandNumber();
     AffineExpr a = map.getResult(perm(enc, d));
-    if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
+    if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d),
+                    op.isOutputTensor(t)))
       return false; // inadmissable affine expression
   }
 }
@@ -347,10 +349,28 @@
 /// Returns true if tensor materializes uninitialized into the computation.
 static bool isMaterializing(Value val) {
+  if (auto l = val.getDefiningOp<LoadOp>())
+    return !l.getHasInserts();
   return val.getDefiningOp<linalg::InitTensorOp>() ||
          val.getDefiningOp<bufferization::AllocTensorOp>();
 }
 
+/// Returns true if the sparse output tensor construction is being finalized,
+/// i.e., this is the last GenericOp in the chain that builds the tensor.
+static bool isFinalizing(Value val) {
+  if (!val.hasOneUse())
+    return true;
+  auto op = dyn_cast<linalg::GenericOp>(*val.getUsers().begin());
+  // We are still in the middle of constructing the tensor if its single use
+  // is as the sole output of another GenericOp.
+  // E.g.,
+  // %x = alloc_tensor
+  // %s = generic (out: %x) => not finalized
+  // %t = generic (out: %s) => finalized
+  // return %t
+  return !(op && op.getOutputs().size() == 1 && op.getOutputs()[0] == val);
+}
+
 /// Returns true when the tensor expression is admissable for codegen.
 /// Since all sparse input tensors are admissable, we just need to check
 /// whether the out tensor in the tensor expression codegen is admissable.
@@ -572,7 +592,13 @@
       codegen.indices[tensor][idx] =
           builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
     } else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) {
-      llvm_unreachable("TODO: not implemented yet");
+      // Singleton dimension, fetch indices.
+      auto dynShape = {ShapedType::kDynamicSize};
+      auto indTp =
+          MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
+      auto dim = builder.getIndexAttr(d);
+      codegen.indices[tensor][idx] =
+          builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
     }
     // Find upper bound in current dimension.
     unsigned p = perm(enc, d);
@@ -1135,6 +1161,12 @@
   if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 ||
       at != codegen.outerParNest)
     return; // not needed at this level
+  if (llvm::any_of(op.getTiedIndexingMap(lhs).getResults(),
+                   [](AffineExpr ae) { return !ae.isa<AffineDimExpr>(); })) {
+    // TODO: it could be supported.
+    assert(false && "Does not support expansion on sparse output with complex "
+                    "indexing map.");
+  }
   // Generate start or end of an expanded access pattern.
   Value tensor = lhs->get();
   Location loc = op.getLoc();
@@ -1490,9 +1522,12 @@
   // Move the insertion indices in lexicographic index order. During access
   // pattern expansion, we can skip setting the innermost dimension.
   if (codegen.sparseOut && !codegen.expValues) {
+    auto enc = getSparseTensorEncoding(codegen.sparseOut->get().getType());
+    AffineExpr a =
+        op.getTiedIndexingMap(codegen.sparseOut).getResult(perm(enc, at));
     Value pos = constantIndex(builder, loc, at);
-    builder.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
-                                    pos);
+    builder.create<memref::StoreOp>(loc, genAffine(codegen, builder, a, loc),
+                                    codegen.lexIdx, pos);
   }
 }
 
@@ -1785,8 +1820,11 @@
   if (getSparseTensorEncoding(resType)) {
     // The sparse tensor rematerializes from the original sparse tensor's
     // underlying sparse storage format.
+    bool isFinal = isFinalizing(op.getResult(0));
+    // Only set the hasInserts attribute when we are finalizing the sparse
+    // output tensor.
     rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
-                                        codegen.sparseOut == lhs);
+                                        codegen.sparseOut == lhs && isFinal);
   } else {
     // To rematerialize an non-annotated tensor, simply load it
     // from the bufferized value.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
@@ -0,0 +1,98 @@
+// RUN: mlir-opt %s --sparsification=enable-runtime-library=false | FileCheck %s
+
+#MAT_C_C = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @concat_sparse_sparse(
+// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64
+// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64
+// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64
+// CHECK: %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK: %[[TMP_c5:.*]] = arith.constant 5 : index
+// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor() : tensor<9x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu-no", "singleton-no" ]
+// CHECK: %[[TMP_1:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 0 : index}
+// CHECK: %[[TMP_2:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 0 : index}
+// CHECK: %[[TMP_3:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 1 : index}
+// CHECK: %[[TMP_4:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 1 : index}
+// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]]
+// CHECK: %[[TMP_6:.*]] = memref.alloca(%[[TMP_c2]]) : memref<?xindex>
+// CHECK: %[[TMP_7:.*]] = memref.alloca() : memref<f64>
+// CHECK: %[[TMP_8:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_9:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_8]] to %[[TMP_9]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_32:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: memref.store %[[TMP_32]], %[[TMP_6]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_33:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_34:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_35:.*]] = memref.load %[[TMP_3]][%[[TMP_34]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_33]] to %[[TMP_35]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_36:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: memref.store %[[TMP_36]], %[[TMP_6]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: %[[TMP_37:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: memref.store %[[TMP_37]], %[[TMP_7]][] : memref<f64>
+// CHECK: sparse_tensor.insert %[[TMP_0]], %[[TMP_6]], %[[TMP_7]]
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_10:.*]] = sparse_tensor.load %[[TMP_0]]
+// CHECK: %[[TMP_11:.*]] = sparse_tensor.pointers %[[TMP_arg1]] {dimension = 0 : index}
+// CHECK: %[[TMP_12:.*]] = sparse_tensor.indices %[[TMP_arg1]] {dimension = 0 : index}
+// CHECK: %[[TMP_13:.*]] = sparse_tensor.pointers %[[TMP_arg1]] {dimension = 1 : index}
+// CHECK: %[[TMP_14:.*]] = sparse_tensor.indices %[[TMP_arg1]] {dimension = 1 : index}
+// CHECK: %[[TMP_15:.*]] = sparse_tensor.values %[[TMP_arg1]]
+// CHECK: %[[TMP_16:.*]] = memref.alloca(%[[TMP_c2]]) : memref<?xindex>
+// CHECK: %[[TMP_17:.*]] = memref.alloca() : memref<f64>
+// CHECK: %[[TMP_18:.*]] = memref.load %[[TMP_11]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_19:.*]] = memref.load %[[TMP_11]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_18]] to %[[TMP_19]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_32:.*]] = memref.load %[[TMP_12]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_33:.*]] = arith.addi %[[TMP_32]], %[[TMP_c2]] : index
+// CHECK: memref.store %[[TMP_33]], %[[TMP_16]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_34:.*]] = memref.load %[[TMP_13]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_35:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_36:.*]] = memref.load %[[TMP_13]][%[[TMP_35]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_34]] to %[[TMP_36]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_37:.*]] = memref.load %[[TMP_14]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: memref.store %[[TMP_37]], %[[TMP_16]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: %[[TMP_38:.*]] = memref.load %[[TMP_15]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: memref.store %[[TMP_38]], %[[TMP_17]][] : memref<f64>
+// CHECK: sparse_tensor.insert %[[TMP_10]], %[[TMP_16]], %[[TMP_17]]
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_20:.*]] = sparse_tensor.load %[[TMP_10]]
+// CHECK: %[[TMP_21:.*]] = sparse_tensor.pointers %[[TMP_arg2]] {dimension = 0 : index}
+// CHECK: %[[TMP_22:.*]] = sparse_tensor.indices %[[TMP_arg2]] {dimension = 0 : index}
+// CHECK: %[[TMP_23:.*]] = sparse_tensor.pointers %[[TMP_arg2]] {dimension = 1 : index}
+// CHECK: %[[TMP_24:.*]] = sparse_tensor.indices %[[TMP_arg2]] {dimension = 1 : index}
+// CHECK: %[[TMP_25:.*]] = sparse_tensor.values %[[TMP_arg2]]
+// CHECK: %[[TMP_26:.*]] = memref.alloca(%[[TMP_c2]]) : memref<?xindex>
+// CHECK: %[[TMP_27:.*]] = memref.alloca() : memref<f64>
+// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_21]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_29:.*]] = memref.load %[[TMP_21]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_28]] to %[[TMP_29]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_32:.*]] = memref.load %[[TMP_22]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_33:.*]] = arith.addi %[[TMP_32]], %[[TMP_c5]] : index
+// CHECK: memref.store %[[TMP_33]], %[[TMP_26]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_34:.*]] = memref.load %[[TMP_23]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_35:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_36:.*]] = memref.load %[[TMP_23]][%[[TMP_35]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg4:.*]] = %[[TMP_34]] to %[[TMP_36]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_37:.*]] = memref.load %[[TMP_24]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: memref.store %[[TMP_37]], %[[TMP_26]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: %[[TMP_38:.*]] = memref.load %[[TMP_25]][%[[TMP_arg4]]] : memref<?xf64>
+// CHECK: memref.store %[[TMP_38]], %[[TMP_27]][] : memref<f64>
+// CHECK: sparse_tensor.insert %[[TMP_20]], %[[TMP_26]], %[[TMP_27]]
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_30:.*]] = sparse_tensor.load %[[TMP_20]] hasInserts
+// CHECK: %[[TMP_31:.*]] = sparse_tensor.convert %[[TMP_30]]
+// CHECK: return %[[TMP_31]]
+func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #MAT_C_C>,
+                                %arg1: tensor<3x4xf64, #MAT_C_C>,
+                                %arg2: tensor<4x4xf64, #MAT_C_C>)
+    -> tensor<9x4xf64, #MAT_C_C> {
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<2x4xf64, #MAT_C_C>,
+         tensor<3x4xf64, #MAT_C_C>,
+         tensor<4x4xf64, #MAT_C_C> to tensor<9x4xf64, #MAT_C_C>
+  return %0 : tensor<9x4xf64, #MAT_C_C>
+}