diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -27,9 +27,43 @@
 namespace mlir {
 namespace sparse_tensor {
+
 /// Convenience method to get a sparse encoding attribute from a type.
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
+
+//
+// Dimension level types.
+//
+
+bool isDenseDim(SparseTensorEncodingAttr::DimLevelType dltp);
+bool isCompressedDim(SparseTensorEncodingAttr::DimLevelType dltp);
+bool isSingletonDim(SparseTensorEncodingAttr::DimLevelType dltp);
+
+/// Convenience method to test for dense dimension (0 <= d < rank).
+bool isDenseDim(RankedTensorType type, uint64_t d);
+
+/// Convenience method to test for compressed dimension (0 <= d < rank).
+bool isCompressedDim(RankedTensorType type, uint64_t d);
+
+/// Convenience method to test for singleton dimension (0 <= d < rank).
+bool isSingletonDim(RankedTensorType type, uint64_t d);
+
+//
+// Dimension level properties.
+//
+
+bool isOrderedDim(SparseTensorEncodingAttr::DimLevelType dltp);
+bool isUniqueDim(SparseTensorEncodingAttr::DimLevelType dltp);
+
+/// Convenience method to test for ordered property in the
+/// given dimension (0 <= d < rank).
+bool isOrderedDim(RankedTensorType type, uint64_t d);
+
+/// Convenience method to test for unique property in the
+/// given dimension (0 <= d < rank).
+bool isUniqueDim(RankedTensorType type, uint64_t d);
+
 } // namespace sparse_tensor
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -216,6 +216,10 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Convenience Methods.
+//===----------------------------------------------------------------------===//
+
 SparseTensorEncodingAttr
 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
   if (auto ttp = type.dyn_cast<RankedTensorType>())
@@ -223,6 +227,98 @@
   return nullptr;
 }
 
+bool mlir::sparse_tensor::isDenseDim(
+    SparseTensorEncodingAttr::DimLevelType dltp) {
+  return dltp == SparseTensorEncodingAttr::DimLevelType::Dense;
+}
+
+bool mlir::sparse_tensor::isCompressedDim(
+    SparseTensorEncodingAttr::DimLevelType dltp) {
+  switch (dltp) {
+  case SparseTensorEncodingAttr::DimLevelType::Compressed:
+  case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+  case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+  case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+    return true;
+  default:
+    return false;
+  }
+}
+
+bool mlir::sparse_tensor::isSingletonDim(
+    SparseTensorEncodingAttr::DimLevelType dltp) {
+  switch (dltp) {
+  case SparseTensorEncodingAttr::DimLevelType::Singleton:
+  case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+  case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+  case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+    return true;
+  default:
+    return false;
+  }
+}
+
+bool mlir::sparse_tensor::isDenseDim(RankedTensorType type, uint64_t d) {
+  assert(d < static_cast<uint64_t>(type.getRank()));
+  if (auto enc = getSparseTensorEncoding(type))
+    return isDenseDim(enc.getDimLevelType()[d]);
+  return true; // unannotated tensor is dense
+}
+
+bool mlir::sparse_tensor::isCompressedDim(RankedTensorType type, uint64_t d) {
+  assert(d < static_cast<uint64_t>(type.getRank()));
+  if (auto enc = getSparseTensorEncoding(type))
+    return isCompressedDim(enc.getDimLevelType()[d]);
+  return false; // unannotated tensor is dense
+}
+
+bool mlir::sparse_tensor::isSingletonDim(RankedTensorType type, uint64_t d) {
+  assert(d < static_cast<uint64_t>(type.getRank()));
+  if (auto enc = getSparseTensorEncoding(type))
+    return isSingletonDim(enc.getDimLevelType()[d]);
+  return false; // unannotated tensor is dense
+}
+
+bool mlir::sparse_tensor::isOrderedDim(
+    SparseTensorEncodingAttr::DimLevelType dltp) {
+  switch (dltp) {
+  case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+  case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+  case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+  case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+    return false;
+  default:
+    return true;
+  }
+}
+
+bool mlir::sparse_tensor::isUniqueDim(
+    SparseTensorEncodingAttr::DimLevelType dltp) {
+  switch (dltp) {
+  case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+  case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+  case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+  case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+    return false;
+  default:
+    return true;
+  }
+}
+
+bool mlir::sparse_tensor::isOrderedDim(RankedTensorType type, uint64_t d) {
+  assert(d < static_cast<uint64_t>(type.getRank()));
+  if (auto enc = getSparseTensorEncoding(type))
+    return isOrderedDim(enc.getDimLevelType()[d]);
+  return true; // unannotated tensor is dense (and thus ordered)
+}
+
+bool mlir::sparse_tensor::isUniqueDim(RankedTensorType type, uint64_t d) {
+  assert(d < static_cast<uint64_t>(type.getRank()));
+  if (auto enc = getSparseTensorEncoding(type))
+    return isUniqueDim(enc.getDimLevelType()[d]);
+  return true; // unannotated tensor is dense (and thus unique)
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -103,37 +103,28 @@
 /// Returns field index of sparse tensor type for pointers/indices, when set.
 static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
-  auto enc = getSparseTensorEncoding(type);
-  assert(enc);
+  assert(getSparseTensorEncoding(type));
   RankedTensorType rType = type.cast<RankedTensorType>();
   unsigned field = 2; // start past sizes
   unsigned ptr = 0;
   unsigned idx = 0;
   for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
-    switch (enc.getDimLevelType()[r]) {
-    case SparseTensorEncodingAttr::DimLevelType::Dense:
-      break; // no fields
-    case SparseTensorEncodingAttr::DimLevelType::Compressed:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+    if (isCompressedDim(rType, r)) {
       if (ptr++ == ptrDim)
        return field;
       field++;
       if (idx++ == idxDim)
         return field;
       field++;
-      break;
-    case SparseTensorEncodingAttr::DimLevelType::Singleton:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+    } else if (isSingletonDim(rType, r)) {
       if (idx++ == idxDim)
         return field;
       field++;
-      break;
+    } else {
+      assert(isDenseDim(rType, r)); // no fields
     }
   }
+  assert(ptrDim == -1u && idxDim == -1u);
   return field + 1; // return values field index
 }
 
@@ -176,7 +167,7 @@
   // The dimSizes array.
   fields.push_back(MemRefType::get({rank}, indexType));
   // The memSizes array.
-  unsigned lastField = getFieldIndex(type, -1, -1);
+  unsigned lastField = getFieldIndex(type, -1u, -1u);
   fields.push_back(MemRefType::get({lastField - 2}, indexType));
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
@@ -184,22 +175,13 @@
     // As a result, the compound type can be constructed directly in the given
     // order. Clients of this type know what field is what from the sparse
     // tensor type.
-    switch (enc.getDimLevelType()[r]) {
-    case SparseTensorEncodingAttr::DimLevelType::Dense:
-      break; // no fields
-    case SparseTensorEncodingAttr::DimLevelType::Compressed:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+    if (isCompressedDim(rType, r)) {
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-      break;
-    case SparseTensorEncodingAttr::DimLevelType::Singleton:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+    } else if (isSingletonDim(rType, r)) {
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-      break;
+    } else {
+      assert(isDenseDim(rType, r)); // no fields
     }
   }
   // The values array.
@@ -254,7 +236,7 @@
       builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
   fields.push_back(dimSizes);
   // The sizes array.
-  unsigned lastField = getFieldIndex(type, -1, -1);
+  unsigned lastField = getFieldIndex(type, -1u, -1u);
   Value memSizes = builder.create<memref::AllocOp>(
       loc, MemRefType::get({lastField - 2}, indexType));
   fields.push_back(memSizes);
@@ -265,25 +247,16 @@
     builder.create<memref::StoreOp>(loc, sizes[ro], dimSizes,
                                     constantIndex(builder, loc, r));
     linear = builder.create<arith::MulIOp>(loc, linear, sizes[ro]);
-    // Allocate fiels.
-    switch (enc.getDimLevelType()[r]) {
-    case SparseTensorEncodingAttr::DimLevelType::Dense:
-      break; // no fields
-    case SparseTensorEncodingAttr::DimLevelType::Compressed:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
-    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+    // Allocate fields.
+    if (isCompressedDim(rType, r)) {
       fields.push_back(createAllocation(builder, loc, ptrType, heuristic));
       fields.push_back(createAllocation(builder, loc, idxType, heuristic));
       allDense = false;
-      break;
-    case SparseTensorEncodingAttr::DimLevelType::Singleton:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
-    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+    } else if (isSingletonDim(rType, r)) {
       fields.push_back(createAllocation(builder, loc, idxType, heuristic));
       allDense = false;
-      break;
+    } else {
+      assert(isDenseDim(rType, r)); // no fields
     }
   }
   // The values array. For all-dense, the full length is required.
@@ -507,7 +480,8 @@
   matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
+    RankedTensorType srcType =
+        op.getTensor().getType().cast<RankedTensorType>();
     Type eltType = srcType.getElementType();
     Type boolType = rewriter.getIntegerType(1);
     Type idxType = rewriter.getIndexType();
@@ -561,17 +535,18 @@
   matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
-    Type eltType = srcType.getElementType();
+    RankedTensorType dstType =
+        op.getTensor().getType().cast<RankedTensorType>();
+    Type eltType = dstType.getElementType();
     Value values = adaptor.getValues();
     Value filled = adaptor.getFilled();
     Value added = adaptor.getAdded();
     Value count = adaptor.getCount();
-
-    //
-    // TODO: need to implement "std::sort(added, added + count);" for ordered
-    //
-
+    // If the innermost dimension is ordered, we need to sort the indices
+    // in the "added" array prior to applying the compression.
+    unsigned rank = dstType.getShape().size();
+    if (isOrderedDim(dstType, rank - 1))
+      rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{});
     // While performing the insertions, we also need to reset the elements
     // of the values/filled-switch by only iterating over the set elements,
     // to ensure that the runtime complexity remains proportional to the
@@ -699,7 +674,7 @@
   static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
                                 ToPointersOp op) {
     uint64_t dim = op.getDimension().getZExtValue();
-    return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1);
+    return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1u);
   }
 };
 
@@ -712,7 +687,7 @@
   static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
                                 ToIndicesOp op) {
     uint64_t dim = op.getDimension().getZExtValue();
-    return getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/dim);
+    return getFieldIndex(op.getTensor().getType(), -1u, /*idxDim=*/dim);
   }
 };
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -156,8 +156,9 @@
     RewritePatternSet patterns(ctx);
     SparseTensorTypeToBufferConverter converter;
     ConversionTarget target(*ctx);
-    // Everything in the sparse dialect must go!
+    // Most ops in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
+    target.addLegalOp<SortOp>();
     // All dynamic rules below accept new function, call, return, and various
     // tensor and bufferization operations as legal output of the rewriting
     // provided that all sparse tensor types have been fully rewritten.
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -24,6 +24,10 @@
   pointerBitWidth = 32
 }>
 
+#UCSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed-no" ]
+}>
+
 #CSC = #sparse_tensor.encoding<{
   dimLevelType = [ "dense", "compressed" ],
   dimOrdering = affine_map<(i, j) -> (j, i)>
@@ -363,7 +367,7 @@
 // CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// TODO: sort
+// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
 // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
 // CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
 // TODO: insert
@@ -385,6 +389,43 @@
   return
 }
 
+// CHECK-LABEL: func @sparse_compression_unordered(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
+// CHECK-SAME: %[[A8:.*8]]: index,
+// CHECK-SAME: %[[A9:.*9]]: index)
+// CHECK-DAG: %[[B0:.*]] = arith.constant false
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NOT: sparse_tensor.sort
+// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
+// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
+// TODO: insert
+// CHECK-DAG: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
+// CHECK-DAG: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
+// CHECK-NEXT: }
+// CHECK-DAG: memref.dealloc %[[A5]] : memref<?xf64>
+// CHECK-DAG: memref.dealloc %[[A6]] : memref<?xi1>
+// CHECK-DAG: memref.dealloc %[[A7]] : memref<?xindex>
+// CHECK: return
+func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
+                                        %values: memref<?xf64>,
+                                        %filled: memref<?xi1>,
+                                        %added: memref<?xindex>,
+                                        %count: index,
+                                        %i: index) {
+  sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+    : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #UCSR>
+  return
+}
+
 // CHECK-LABEL: func @sparse_push_back(
 // CHECK-SAME: %[[A:.*]]: memref<?xindex>,
 // CHECK-SAME: %[[B:.*]]: memref<?xf64>,