diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -33,6 +33,16 @@
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
+/// Reorders stored dimension to logical dimension.
+static unsigned reorder(const SparseTensorEncodingAttr &enc, unsigned d) {
+  auto order = enc.getDimOrdering();
+  if (order) {
+    assert(order.isPermutation());
+    return order.getDimPosition(d);
+  }
+  return d;
+}
+
 /// Maps a sparse tensor type to the appropriate compounded buffers.
 static Optional<Type> convertSparseTensorType(Type type) {
   auto enc = getSparseTensorEncoding(type);
@@ -47,12 +57,14 @@
   Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
   Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
   Type eltType = rType.getElementType();
+  ArrayRef<int64_t> shape = rType.getShape();
   //
   // Sparse tensor storage for rank-dimensional tensor is organized as a
   // single compound type with the following fields:
   //
   // struct {
-  //   memref<rank x index> dimSize    ; size in each dimension
+  //   ; if dynamic shape:
+  //     memref<rank x index> dimSize  ; size in each dimension
   //   ; per-dimension d:
   //   ;  if dense:
   //        <nothing>
@@ -61,23 +73,31 @@
   //     memref<? x ptr> pointers-d    ; pointers for sparse dim d
   //   ;  if singleton:
   //     memref<? x idx> indices-d     ; indices for singleton dim d
-  //   memref<?xeltType> values        ; values
+  //   memref<? x eltType> values      ; values
   // };
   //
-  // TODO: fill in the ? when statically known
-  //
-  // TODO: emit dimSizes when not needed (e.g. all-dense)
-  //
+  int64_t linear = 1;
+  bool allDense = true;
   unsigned rank = rType.getShape().size();
   SmallVector<Type> fields;
-  fields.push_back(MemRefType::get({rank}, indexType));
+  // The dimSizes array.
+  if (!rType.hasStaticShape())
+    fields.push_back(MemRefType::get({rank}, indexType));
+  // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
+    // Get the original dimension (ro) for the current stored dimension (r).
+    unsigned ro = reorder(enc, r);
     // Dimension level types apply in order to the reordered dimension.
     // As a result, the compound type can be constructed directly in the given
     // order. Clients of this type know what field is what from the sparse
     // tensor type.
     switch (enc.getDimLevelType()[r]) {
     case SparseTensorEncodingAttr::DimLevelType::Dense:
+      // Linearize the size of consecutive dense dimensions.
+      if (ShapedType::isDynamic(shape[ro]) || ShapedType::isDynamic(linear))
+        linear = ShapedType::kDynamicSize;
+      else
+        linear *= shape[ro];
       break;
     case SparseTensorEncodingAttr::DimLevelType::Compressed:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
@@ -85,16 +105,23 @@
     case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
+      allDense = false;
+      linear = 1;
       break;
     case SparseTensorEncodingAttr::DimLevelType::Singleton:
    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
+      allDense = false;
+      linear = 1;
       break;
     }
   }
-  fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
+  // The values array.
+  int64_t nnz =
+      (rType.hasStaticShape() && allDense) ? linear : ShapedType::kDynamicSize;
+  fields.push_back(MemRefType::get({nnz}, eltType));
   // Sparse tensor storage (temporarily) lives in a tuple. This allows a
   // simple 1:1 type conversion during codegen. A subsequent pass uses
   // a 1:N type conversion to expand the tuple into its fields.
@@ -102,10 +129,10 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Conversion rules.
+// Codegen rules.
 //===----------------------------------------------------------------------===//
 
-/// Sparse conversion rule for returns.
+/// Sparse codegen rule for returns.
 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -117,6 +144,36 @@
   }
 };
 
+/// Sparse codegen rule for dimension accesses.
+class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Type type = op.getSource().getType();
+    // Only rewrite annotated DimOp with constant index.
+    auto enc = getSparseTensorEncoding(type);
+    if (!enc)
+      return failure();
+    Optional<int64_t> index = op.getConstantIndex();
+    if (!index)
+      return failure();
+    // Access into static shape can query original type directly.
+    // Note that this is typically already done by DimOp's folding.
+    RankedTensorType rType = type.cast<RankedTensorType>();
+    if (rType.hasStaticShape()) {
+      rewriter.replaceOp(
+          op, constantIndex(rewriter, loc, rType.getShape()[*index]));
+      return success();
+    }
+    // Any other query can consult the dimSize array.
+    // TODO: this needs tuple access
+    return failure();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -136,5 +193,6 @@
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
-  patterns.add<SparseReturnConverter>(typeConverter, patterns.getContext());
+  patterns.add<SparseDimOpConverter, SparseReturnConverter>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -158,8 +158,7 @@
     ConversionTarget target(*ctx);
     // Everything in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
-    // All dynamic rules below accept new function, call, return, and various
-    // tensor and bufferization operations as legal output of the rewriting.
+    // All dynamic rules below accept new function, call, return.
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return converter.isSignatureLegal(op.getFunctionType());
     });
@@ -169,6 +168,10 @@
     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
       return converter.isLegal(op.getOperandTypes());
     });
+    // Legal dialects may occur in generated code.
+    target.addLegalDialect<arith::ArithmeticDialect,
+                           bufferization::BufferizationDialect,
+                           memref::MemRefDialect>();
     // Populate with rules and apply rewriting rules.
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                    converter);
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -6,7 +6,7 @@
   pointerBitWidth = 32
 }>
 
-#Dense = #sparse_tensor.encoding<{
+#Dense2D = #sparse_tensor.encoding<{
   dimLevelType = [ "dense", "dense" ],
   indexBitWidth = 64,
   pointerBitWidth = 32
@@ -30,6 +30,13 @@
   pointerBitWidth = 32
 }>
 
+#Dense3D = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "dense", "dense" ],
+  indexBitWidth = 64,
+  pointerBitWidth = 32,
+  dimOrdering = affine_map<(i,j,k) -> (k, i,j)>
+}>
+
 // CHECK-LABEL: func @sparse_nop(
 // CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>) -> tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>
 // CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>
@@ -37,9 +44,9 @@
   return %arg0 : tensor<?xf64, #SparseVector>
 }
 
-// CHECK-LABEL: func @sparse_dense(
+// CHECK-LABEL: func @sparse_dense_2d(
 // CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xf64>>)
-func.func @sparse_dense(%arg0: tensor<?x?xf64, #Dense2D>) {
+func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
   return
 }
 
@@ -60,3 +67,16 @@
 func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
   return
 }
+
+//
+// Just a linearized array in the end. Dim op is statically known.
+//
+// CHECK-LABEL: func @sparse_dense_3d(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<6000xf64>>) -> index
+// CHECK: %[[C:.*]] = arith.constant 20 : index
+// CHECK: return %[[C]] : index
+func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
+  %c = arith.constant 1 : index
+  %0 = tensor.dim %arg0, %c : tensor<10x20x30xf64, #Dense3D>
+  return %0 : index
+}