diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -111,6 +111,20 @@
   return isZeroValue(yieldOp.getOperand(0));
 }
 
+/// Populates the given sizes array from the type (for static sizes) and from
+/// the tensor (for dynamic sizes).
+static void sizesForTensor(OpBuilder &builder, SmallVector<Value> &sizes,
+                           Location loc, ShapedType stp, Value tensor) {
+  for (const auto &d : enumerate(stp.getShape())) {
+    Value dim;
+    if (d.value() == ShapedType::kDynamicSize)
+      dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
+    else
+      dim = constantIndex(builder, loc, d.value());
+    sizes.push_back(dim);
+  }
+}
+
 // TODO: The dim level property of the COO type relies on input tensors, the
 //       shape relies on the output tensor.
 // Helpers to setup a COO type.
@@ -119,8 +133,10 @@
   auto rank = src.getRank();
   SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dims;
 
-  // An unordered and non-unique compressed dim at beginning.
-  dims.push_back(SparseTensorEncodingAttr::DimLevelType::CompressedNuNo);
+  // An unordered and non-unique compressed dim at beginning, unless the
+  // tensor is a 1D tensor.
+  if (rank > 1)
+    dims.push_back(SparseTensorEncodingAttr::DimLevelType::CompressedNuNo);
   // TODO: it is actually ordered at the level for ordered input.
   // Followed by unordered non-unique n-2 singleton levels.
   std::fill_n(std::back_inserter(dims), rank - 2,
@@ -280,24 +296,91 @@
   }
 };
 
-/// Sparse rewriting rule for reshape operator.
+/// Sparse rewriting rule for reshape operators.
 template <typename ReshapeOp>
 struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
 public:
   using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+  ReshapeRewriter(MLIRContext *context, bool rt);
+
+  static LogicalResult rewriteSparse2SparseReshape(ReshapeOp op,
+                                                   PatternRewriter &rewriter) {
+    Location loc = op.getLoc();
+    Value srcTensor = op.getSrc();
+    auto srcTp = srcTensor.getType().template cast<RankedTensorType>();
+    auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+    SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
+    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
+    assert(encDst && encSrc);
+
+    Type elemTp = srcTp.getElementType();
+    assert(elemTp == dstTp.getElementType() &&
+           "reshape should not change element type");
+
+    // Generate code to represent the static dimension constants or compute
+    // the dynamic dimension values.
+    SmallVector<Value> srcSizes;
+    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+    SmallVector<Value> dstSizes;
+    SmallVector<Value> dstDynSizes;
+    if (dstTp.hasStaticShape()) {
+      for (auto d : dstTp.getShape())
+        dstSizes.push_back(constantIndex(rewriter, loc, d));
+    } else {
+      ArrayRef<int64_t> dstShape = dstTp.getShape();
+      genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
+                         op.getReassociationIndices());
+      for (const auto &d : llvm::enumerate(dstShape)) {
+        if (d.value() == ShapedType::kDynamicSize)
+          dstDynSizes.push_back(dstSizes[d.index()]);
+      }
+    }
+
+    // Implement the sparse2sparse reshape as follows:
+    //   %tmp = bufferization.alloc_tensor : unordered COO
+    //   foreach srcCoords %srcTensor
+    //     insert translateIndicesArray(srcCoords), %tmp
+    //   %t = sparse_tensor.convert %tmp
+    RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
+    auto cooBuffer =
+        rewriter.create<AllocTensorOp>(loc, cooTp, dstDynSizes).getResult();
+    rewriter.create<ForeachOp>(
+        loc, srcTensor,
+        [&](OpBuilder &builder, Location loc, ValueRange args) {
+          SmallVector<Value> srcIndices;
+          SmallVector<Value> dstIndices;
+          for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
+            uint64_t dim = toStoredDim(encSrc, i);
+            srcIndices.push_back(args[dim]);
+          }
+          translateIndicesArray(builder, loc, op.getReassociationIndices(),
+                                srcIndices, srcSizes, dstSizes, dstIndices);
+          builder.create<InsertOp>(loc, args.back(), cooBuffer, dstIndices);
+          builder.create<sparse_tensor::YieldOp>(loc);
+        });
+
+    rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, cooBuffer);
+
+    return success();
+  }
+
   LogicalResult matchAndRewrite(ReshapeOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     auto encDst = getSparseTensorEncoding(op.getResult().getType());
     auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
     // Since a pure dense expansion is very cheap (change of view), for
-    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
-    // conversion from the reshape operation itself.
-    // All other cases are handled elsewhere.
+    // a sparse2dense or dense2sparse, we simply unfuse a sparse conversion
+    // from the reshape operation itself.
+    //
+    // All other cases are handled via rewriteSparse2SparseReshape.
+
     if (encDst && encSrc) {
-      return failure();
+      if (enableRT)
+        return failure();
+      return rewriteSparse2SparseReshape(op, rewriter);
     }
+
     if (encSrc) {
       RankedTensorType rtp =
           op.getSrc().getType().template cast<RankedTensorType>();
@@ -307,6 +390,7 @@
       op->setOperand(0, convert);
       return success();
     }
+
     if (encDst) {
       RankedTensorType rtp =
           op.getResult().getType().template cast<RankedTensorType>();
@@ -318,8 +402,11 @@
       rewriter.replaceOp(op, convert);
       return success();
     }
+
     return failure();
   }
+
+  bool enableRT;
 };
 
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
@@ -431,6 +518,10 @@
   }
 };
 
+template <typename ReshapeOp>
+ReshapeRewriter<ReshapeOp>::ReshapeRewriter(MLIRContext *context, bool rt)
+    : OpRewritePattern<ReshapeOp>(context), enableRT(rt) {}
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -439,10 +530,11 @@
 
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
                                          bool enableRT) {
-  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
-               ReshapeRewriter<tensor::ExpandShapeOp>,
-               ReshapeRewriter<tensor::CollapseShapeOp>, ForeachRewriter>(
+  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, ForeachRewriter>(
       patterns.getContext());
+  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
+               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext(),
+                                                         enableRT);
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
   if (!enableRT)
     patterns.add<ConcatenateRewriter>(patterns.getContext());
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
 // RUN: mlir-opt %s --sparse-tensor-conversion --cse --canonicalize | FileCheck %s --check-prefix=CHECK-CONV
+// RUN: mlir-opt %s --sparsification=enable-runtime-library=false --cse --canonicalize | FileCheck %s --check-prefix=CHECK-RWT
 
 #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
 #SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
@@ -37,6 +38,29 @@
 // CHECK-CONV: call @delSparseTensorCOOF64
 // CHECK-CONV: return %[[N]] : !llvm.ptr
 //
+// rewrite for codegen:
+//
+// CHECK-RWT-LABEL: func.func @sparse_expand(
+// CHECK-RWT-SAME: %[[S:.*]]:
+// CHECK-RWT-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
+// CHECK-RWT: %[[SI:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[I]]] : memref<?xf64>
+// CHECK-RWT: %[[DI0:.*]] = arith.divui %[[SI]], %[[C10]] : index
+// CHECK-RWT: %[[DI1:.*]] = arith.remui %[[SI]], %[[C10]] : index
+// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI0]], %[[DI1]]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
+// CHECK-RWT: return %[[T]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+//
 func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
   %0 = tensor.expand_shape %arg0 [[0, 1]] :
     tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix>
@@ -76,6 +100,37 @@
 // CHECK-CONV: call @delSparseTensorCOOF64
 // CHECK-CONV: return %[[N]] : !llvm.ptr
 //
+// rewrite for codegen:
+//
+// CHECK-RWT-LABEL: func.func @sparse_collapse(
+// CHECK-RWT-SAME: %[[S:.*]]:
+// CHECK-RWT-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[P1:.*]] = sparse_tensor.pointers %[[S]] {dimension = 1 : index}
+// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[S]] {dimension = 1 : index}
+// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
+// CHECK-RWT: %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT: %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
+// CHECK-RWT: %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT: %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
+// CHECK-RWT: scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] {
+// CHECK-RWT: %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
+// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
+// CHECK-RWT: %[[T:.*]] = arith.muli %[[SI0]], %[[C10]] : index
+// CHECK-RWT: %[[DI:.*]] = arith.addi %[[T]], %[[SI1]] : index
+// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI]]]
+// CHECK-RWT: }
+// CHECK-RWT: }
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
+// CHECK-RWT: return %[[T]] : tensor<100xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+//
 func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
   %0 = tensor.collapse_shape %arg0 [[0, 1]] :
     tensor<10x10xf64, #SparseMatrix> into tensor<100xf64, #SparseVector>
@@ -120,6 +175,35 @@
 // CHECK-CONV: call @delSparseTensorCOOF64
 // CHECK-CONV: return %[[N]] : !llvm.ptr
 //
+// rewrite for codegen:
+//
+// CHECK-RWT-LABEL: func.func @dynamic_sparse_expand(
+// CHECK-RWT-SAME: %[[S:.*]]:
+// CHECK-RWT-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[SD:.*]] = tensor.dim %[[S]], %[[C0]]
+// CHECK-RWT: %[[DD0:.*]] = arith.divui %[[SD]], %[[C10]] : index
+// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
+// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
+// CHECK-RWT: %[[SI:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[I]]] : memref<?xf64>
+// CHECK-RWT: %[[T1:.*]] = arith.muli %[[DD0]], %[[C10]] : index
+// CHECK-RWT: %[[T2:.*]] = arith.divui %[[T1]], %[[DD0]] : index
+// CHECK-RWT: %[[DI0:.*]] = arith.divui %[[SI]], %[[T2]] : index
+// CHECK-RWT: %[[T3:.*]] = arith.remui %[[SI]], %[[T2]] : index
+// CHECK-RWT: %[[T4:.*]] = arith.divui %[[T2]], %[[C10]] : index
+// CHECK-RWT: %[[DI1:.*]] = arith.divui %[[T3]], %[[T4]] : index
+// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI0]], %[[DI1]]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
+// CHECK-RWT: return %[[T]] : tensor<?x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+//
 func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?x10xf64, #SparseMatrix> {
   %0 = tensor.expand_shape %arg0 [[0, 1]] :
     tensor<?xf64, #SparseVector> into tensor<?x10xf64, #SparseMatrix>
@@ -163,6 +247,42 @@
 // CHECK-CONV: call @delSparseTensorCOOF64
 // CHECK-CONV: return %[[N]] : !llvm.ptr
 //
+// rewrite for codegen:
+//
+// CHECK-RWT-LABEL: func.func @dynamic_sparse_collapse(
+// CHECK-RWT-SAME: %[[S:.*]]:
+// CHECK-RWT-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[SD1:.*]] = tensor.dim %[[S]], %[[C1]]
+// CHECK-RWT: %[[DD0:.*]] = arith.muli %[[SD1]], %[[C10]] : index
+// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
+// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[P1:.*]] = sparse_tensor.pointers %[[S]] {dimension = 1 : index}
+// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[S]] {dimension = 1 : index}
+// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
+// CHECK-RWT: %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT: %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
+// CHECK-RWT: %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT: %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
+// CHECK-RWT: scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] {
+// CHECK-RWT: %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
+// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
+// CHECK-RWT: %[[T1:.*]] = arith.divui %[[DD0]], %[[C10]] : index
+// CHECK-RWT: %[[T2:.*]] = arith.muli %[[SI0]], %[[T1]] : index
+// CHECK-RWT: %[[T3:.*]] = arith.divui %[[T1]], %[[SD1]] : index
+// CHECK-RWT: %[[T4:.*]] = arith.muli %[[SI1]], %[[T3]] : index
+// CHECK-RWT: %[[DI:.*]] = arith.addi %[[T2]], %[[T4]] : index
+// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI]]]
+// CHECK-RWT: }
+// CHECK-RWT: }
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
+// CHECK-RWT: return %[[T]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+//
 func.func @dynamic_sparse_collapse(%arg0: tensor<10x?xf64, #SparseMatrix>) -> tensor<?xf64, #SparseVector> {
   %0 = tensor.collapse_shape %arg0 [[0, 1]] :
     tensor<10x?xf64, #SparseMatrix> into tensor<?xf64, #SparseVector>
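
Note on the coordinate arithmetic that the CHECK-RWT lines above verify: it is the affine translation emitted by `translateIndicesArray` through the reassociation groups. In the dynamic cases the destination sizes are only known symbolically, which is why the generated IR carries chains like `(DD0 * 10) / DD0` instead of folded constants. The standalone C++ sketch below (not part of the patch; `expandIndex`, `collapseIndex`, and `dynExpandIndex` are hypothetical names introduced purely for illustration) traces the same `divui`/`remui` and `muli`/`addi` sequences for the tested shapes:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Static expand (@sparse_expand): a linear coordinate in tensor<100xf64>
// maps to (d0, d1) in tensor<10x10xf64> via the arith.divui / arith.remui
// pair checked above.
static std::pair<uint64_t, uint64_t> expandIndex(uint64_t src) {
  return {src / 10, src % 10};
}

// Static collapse (@sparse_collapse): (d0, d1) in tensor<10x10xf64> maps
// back to a linear coordinate via arith.muli / arith.addi.
static uint64_t collapseIndex(uint64_t d0, uint64_t d1) {
  return d0 * 10 + d1;
}

// Dynamic expand (@dynamic_sparse_expand): the destination sizes are
// rebuilt at runtime from tensor.dim, so the quotient t2 = (dstDim0 * 10)
// / dstDim0 only simplifies to 10 during execution, mirroring the
// muli/divui chain in the CHECK-RWT lines.
static std::pair<uint64_t, uint64_t> dynExpandIndex(uint64_t src,
                                                    uint64_t srcDim) {
  uint64_t dstDim0 = srcDim / 10; // arith.divui %[[SD]], %[[C10]]
  uint64_t t1 = dstDim0 * 10;     // arith.muli %[[DD0]], %[[C10]]
  uint64_t t2 = t1 / dstDim0;     // arith.divui (== 10)
  uint64_t d0 = src / t2;         // arith.divui -> %[[DI0]]
  uint64_t t3 = src % t2;         // arith.remui
  uint64_t t4 = t2 / 10;          // arith.divui (== 1)
  uint64_t d1 = t3 / t4;          // arith.divui -> %[[DI1]]
  return {d0, d1};
}

int main() {
  // A nonzero stored at linear coordinate 42 expands to (4, 2) and
  // round-trips back through the collapse.
  assert(expandIndex(42) == std::make_pair(uint64_t{4}, uint64_t{2}));
  assert(collapseIndex(4, 2) == 42);
  // The dynamic path computes the same coordinates for a 50-element
  // source expanded to 5x10.
  assert(dynExpandIndex(42, 50) == std::make_pair(uint64_t{4}, uint64_t{2}));
  return 0;
}
```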