diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -142,6 +142,10 @@
   }];
   let constructor = "mlir::createSparseTensorCodegenPass()";
   let dependentDialects = [
+    "arith::ArithmeticDialect",
+    "bufferization::BufferizationDialect",
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
   ];
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -33,14 +33,24 @@
 // Helper methods.
 //===----------------------------------------------------------------------===//

-/// Reorders stored dimension to logical dimension.
-static unsigned reorder(const SparseTensorEncodingAttr &enc, unsigned d) {
+/// Reorders stored dimension to original dimension.
+static unsigned toOrig(const SparseTensorEncodingAttr &enc, unsigned i) {
   auto order = enc.getDimOrdering();
   if (order) {
     assert(order.isPermutation());
-    return order.getDimPosition(d);
+    return order.getDimPosition(i);
   }
-  return d;
+  return i;
+}
+
+/// Reorders original dimension to stored dimension.
+static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
+  auto order = enc.getDimOrdering();
+  if (order) {
+    assert(order.isPermutation());
+    return order.getPermutedPosition(i);
+  }
+  return i;
 }

 /// Maps a sparse tensor type to the appropriate compounded buffers.
@@ -63,14 +73,13 @@
   // single compound type with the following fields:
   //
   // struct {
-  //   ; if dynamic shape:
-  //     memref dimSize      ; size in each dimension
+  //   memref dimSizes       ; size in each dimension
   //   ; per-dimension d:
   //   ;  if dense:
   //
   //   ;  if compressed:
-  //     memref indices-d    ; indices for sparse dim d
   //     memref pointers-d   ; pointers for sparse dim d
+  //     memref indices-d    ; indices for sparse dim d
   //   ;  if singleton:
   //     memref indices-d    ; indices for singleton dim d
   //   memref values         ; values
@@ -81,12 +90,11 @@
   unsigned rank = rType.getShape().size();
   SmallVector fields;
   // The dimSizes array.
-  if (!rType.hasStaticShape())
-    fields.push_back(MemRefType::get({rank}, indexType));
+  fields.push_back(MemRefType::get({rank}, indexType));
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
     // Get the original dimension (ro) for the current stored dimension (r).
-    unsigned ro = reorder(enc, r);
+    unsigned ro = toOrig(enc, r);
     // Dimension level types apply in order to the reordered dimension.
     // As a result, the compound type can be constructed directly in the given
     // order. Clients of this type know what field is what from the sparse
@@ -103,8 +111,8 @@
     case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
-      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
+      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
       allDense = false;
       linear = 1;
       break;
@@ -128,6 +136,64 @@
   return TupleType::get(context, fields);
 }

+// Returns the field index for the requested pointers (ptrDim) or indices
+// (idxDim) dimension, when set; otherwise returns the values field index.
+static unsigned getFieldIndex(Type type, unsigned ptrDim = -1,
+                              unsigned idxDim = -1) {
+  auto enc = getSparseTensorEncoding(type);
+  assert(enc);
+  RankedTensorType rType = type.cast<RankedTensorType>();
+  unsigned field = 1; // start past the dimSizes field at 0
+  unsigned ptr = 0;
+  unsigned idx = 0;
+  for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
+    switch (enc.getDimLevelType()[r]) {
+    case SparseTensorEncodingAttr::DimLevelType::Dense:
+      break; // no fields
+    case SparseTensorEncodingAttr::DimLevelType::Compressed:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+      if (ptr++ == ptrDim)
+        return field;
+      field++;
+      if (idx++ == idxDim)
+        return field;
+      field++;
+      break;
+    case SparseTensorEncodingAttr::DimLevelType::Singleton:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+      if (idx++ == idxDim)
+        return field;
+      field++;
+      break;
+    }
+  }
+  return field; // values position
+}
+
+/// Returns field type in tuple at given position.
+static Type getFieldType(Value tuple, unsigned field) {
+  return tuple.getType().cast<TupleType>().getType(field);
+}
+
+/// Returns integral constant, if defined.
+static Optional<int64_t> getConstantInt(Value val) {
+  if (auto constantOp = val.getDefiningOp<arith::ConstantOp>())
+    return constantOp.getValue().cast<IntegerAttr>().getInt();
+  return {};
+}
+
+/// Creates tuple get operation.
+static Value createTupleGet(OpBuilder &builder, Location loc, Type resType,
+                            Value source, unsigned field) {
+  Type indexType = builder.getIndexType();
+  return builder.create<StorageGetOp>(loc, resType, source,
+                                      builder.getIntegerAttr(indexType, field));
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -151,26 +217,86 @@
   LogicalResult
   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    Type type = op.getSource().getType();
     // Only rewrite annotated DimOp with constant index.
-    auto enc = getSparseTensorEncoding(type);
+    auto enc = getSparseTensorEncoding(op.getSource().getType());
     if (!enc)
       return failure();
-    Optional<int64_t> index = op.getConstantIndex();
+    Optional<int64_t> index = getConstantInt(adaptor.getIndex());
     if (!index)
       return failure();
-    // Access into static shape can query original type directly.
+    // Access into static dimension can query original type directly.
     // Note that this is typically already done by DimOp's folding.
-    RankedTensorType rType = type.cast<RankedTensorType>();
-    if (rType.hasStaticShape()) {
-      rewriter.replaceOp(
-          op, constantIndex(rewriter, loc, rType.getShape()[*index]));
+    Location loc = op->getLoc();
+    auto shape = op.getSource().getType().cast<RankedTensorType>().getShape();
+    if (!ShapedType::isDynamic(shape[*index])) {
+      rewriter.replaceOp(op, constantIndex(rewriter, loc, shape[*index]));
       return success();
     }
-    // Any other query can consult the dimSize array.
-    // TODO: this needs tuple access
-    return failure();
+    // Any other query can consult the dimSizes array at field 0,
+    // accounting for the reordering applied to the sparse storage.
+    Value tuple = adaptor.getSource();
+    Value dimSizes =
+        createTupleGet(rewriter, loc, getFieldType(tuple, 0), tuple, 0);
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(
+        op, dimSizes, constantIndex(rewriter, loc, toStored(enc, *index)));
+    return success();
+  }
+};
+
+/// Sparse conversion rule for pointer accesses.
+class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
+    if (!index)
+      return failure();
+    // Replace the requested pointer access with corresponding field.
+    Location loc = op->getLoc();
+    Value tuple = adaptor.getTensor();
+    unsigned i = getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*index);
+    rewriter.replaceOp(
+        op, createTupleGet(rewriter, loc, getFieldType(tuple, i), tuple, i));
+    return success();
+  }
+};
+
+/// Sparse conversion rule for index accesses.
+class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
+    if (!index)
+      return failure();
+    // Replace the requested indices access with corresponding field.
+    Location loc = op->getLoc();
+    Value tuple = adaptor.getTensor();
+    unsigned i =
+        getFieldIndex(op.getTensor().getType(), /*ptrDim=*/-1, /*idxDim=*/*index);
+    rewriter.replaceOp(
+        op, createTupleGet(rewriter, loc, getFieldType(tuple, i), tuple, i));
+    return success();
+  }
+};
+
+/// Sparse conversion rule for value accesses.
+class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Replace the requested values access with corresponding field.
+    Location loc = op->getLoc();
+    Value tuple = adaptor.getTensor();
+    unsigned i = getFieldIndex(op.getTensor().getType());
+    rewriter.replaceOp(
+        op, createTupleGet(rewriter, loc, getFieldType(tuple, i), tuple, i));
+    return success();
   }
 };
@@ -193,6 +319,7 @@
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
-  patterns.add(
-      typeConverter, patterns.getContext());
+  patterns.add(typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -156,8 +156,9 @@
   RewritePatternSet patterns(ctx);
   SparseTensorTypeToBufferConverter converter;
   ConversionTarget target(*ctx);
-  // Everything in the sparse dialect must go!
+  // Almost everything in the sparse dialect must go!
   target.addIllegalDialect<SparseTensorDialect>();
+  target.addLegalOp<StorageGetOp>();
   // All dynamic rules below accept new function, call, return.
   target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
     return converter.isSignatureLegal(op.getFunctionType());
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -32,14 +32,12 @@
 #Dense3D = #sparse_tensor.encoding<{
   dimLevelType = [ "dense", "dense", "dense" ],
-  indexBitWidth = 64,
-  pointerBitWidth = 32,
-  dimOrdering = affine_map<(i,j,k) -> (k, i,j)>
+  dimOrdering = affine_map<(i, j, k) -> (k, i, j)>
 }>

 // CHECK-LABEL: func @sparse_nop(
-// CHECK-SAME: %[[A:.*]]: tuple, memref, memref, memref>) -> tuple, memref, memref, memref>
-// CHECK: return %[[A]] : tuple, memref, memref, memref>
+// CHECK-SAME: %[[A:.*]]: tuple, memref, memref, memref>) -> tuple, memref, memref, memref>
+// CHECK: return %[[A]] : tuple, memref, memref, memref>
 func.func @sparse_nop(%arg0: tensor) -> tensor {
   return %arg0 : tensor
 }
@@ -51,28 +49,29 @@
 }

 // CHECK-LABEL: func @sparse_row(
-// CHECK-SAME: %[[A:.*]]: tuple, memref, memref, memref>)
+// CHECK-SAME: %[[A:.*]]: tuple, memref, memref, memref>)
 func.func @sparse_row(%arg0: tensor) {
   return
 }

 // CHECK-LABEL: func @sparse_csr(
-// CHECK-SAME: %[[A:.*]]: tuple, memref, memref, memref>)
+// CHECK-SAME: %[[A:.*]]: tuple, memref, memref, memref>)
 func.func @sparse_csr(%arg0: tensor) {
   return
 }

 // CHECK-LABEL: func @sparse_dcsr(
-// CHECK-SAME: %[[A:.*]]: tuple, memref, memref, memref, memref, memref>)
+// CHECK-SAME: %[[A:.*]]: tuple, memref, memref, memref, memref, memref>)
 func.func @sparse_dcsr(%arg0: tensor) {
   return
 }

 //
-// Just a linearized array in the end. Dim op is statically known.
+// Querying for dimension 1 in the tensor type can immediately
+// fold using the original static dimension sizes.
 //
 // CHECK-LABEL: func @sparse_dense_3d(
-// CHECK-SAME: %[[A:.*]]: tuple>) -> index
+// CHECK-SAME: %[[A:.*]]: tuple, memref<6000xf64>>) -> index {
 // CHECK: %[[C:.*]] = arith.constant 20 : index
 // CHECK: return %[[C]] : index
 func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -80,3 +79,49 @@
   %0 = tensor.dim %arg0, %c : tensor<10x20x30xf64, #Dense3D>
   return %0 : index
 }
+
+//
+// Querying for dimension 1 in the tensor type needs to be permuted
+// into querying for dimension 2 in the stored sparse tensor scheme,
+// since the latter honors the dimOrdering.
+//
+// CHECK-LABEL: func @sparse_dense_3d_dyn(
+// CHECK-SAME: %[[A:.*]]: tuple, memref>) -> index {
+// CHECK: %[[C:.*]] = arith.constant 2 : index
+// CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple, memref> to memref<3xindex>
+// CHECK: %[[L:.*]] = memref.load %[[F]][%[[C]]] : memref<3xindex>
+// CHECK: return %[[L]] : index
+func.func @sparse_dense_3d_dyn(%arg0: tensor) -> index {
+  %c = arith.constant 1 : index
+  %0 = tensor.dim %arg0, %c : tensor
+  return %0 : index
+}
+
+// CHECK-LABEL: func @sparse_pointers_dcsr(
+// CHECK-SAME: %[[A:.*]]: tuple, memref, memref, memref, memref, memref>)
+// CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple, memref, memref, memref, memref, memref> to memref
+// CHECK: return %[[F]] : memref
+func.func @sparse_pointers_dcsr(%arg0: tensor) -> memref {
+  %c = arith.constant 1 : index
+  %0 = sparse_tensor.pointers %arg0, %c : tensor to memref
+  return %0 : memref
+}
+
+// CHECK-LABEL: func @sparse_indices_dcsr(
+// CHECK-SAME: %[[A:.*]]: tuple, memref, memref, memref, memref, memref>)
+// CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][4] : tuple, memref, memref, memref, memref, memref> to memref
+// CHECK: return %[[F]] : memref
+func.func @sparse_indices_dcsr(%arg0: tensor) -> memref {
+  %c = arith.constant 1 : index
+  %0 = sparse_tensor.indices %arg0, %c : tensor to memref
+  return %0 : memref
+}
+
+// CHECK-LABEL: func @sparse_values_dcsr(
+// CHECK-SAME: %[[A:.*]]: tuple, memref, memref, memref, memref, memref>)
+// CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][5] : tuple, memref, memref, memref, memref, memref> to memref
+// CHECK: return %[[F]] : memref
+func.func @sparse_values_dcsr(%arg0: tensor) -> memref {
+  %0 = sparse_tensor.values %arg0 : tensor to memref
+  return %0 : memref
+}
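Note on the field numbering exercised by the new tests: in the compounded storage introduced above, field 0 always holds the dimSizes array, each compressed dimension then contributes a pointers field followed by an indices field, each singleton dimension contributes an indices field, and the values array comes last. The sketch below is illustrative only; it spells out the positions for the rank-2 DCSR tensor used in the tests, calling the getFieldIndex helper added in this patch with -1 (its default argument) meaning "not requested":

    // Rank-2 tensor with dimLevelType = ["compressed", "compressed"] (DCSR):
    //   field 0 : dimSizes
    //   field 1 : pointers-0  == getFieldIndex(type, /*ptrDim=*/0)
    //   field 2 : indices-0   == getFieldIndex(type, /*ptrDim=*/-1, /*idxDim=*/0)
    //   field 3 : pointers-1  == getFieldIndex(type, /*ptrDim=*/1)
    //   field 4 : indices-1   == getFieldIndex(type, /*ptrDim=*/-1, /*idxDim=*/1)
    //   field 5 : values      == getFieldIndex(type)
    //
    // These are the storage_get positions [3], [4], and [5] checked in
    // @sparse_pointers_dcsr, @sparse_indices_dcsr, and @sparse_values_dcsr above.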