diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -33,14 +33,72 @@
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Maps each sparse tensor type to the appropriate buffer.
-static Optional<Type> convertSparseTensorTypes(Type type) {
-  if (getSparseTensorEncoding(type) != nullptr) {
-    // TODO: this is just a dummy rule to get the ball rolling....
-    RankedTensorType rTp = type.cast<RankedTensorType>();
-    return MemRefType::get({ShapedType::kDynamicSize}, rTp.getElementType());
+/// Maps a sparse tensor type to the appropriate compound buffers.
+static Optional<Type> convertSparseTensorType(Type type) {
+  auto enc = getSparseTensorEncoding(type);
+  if (!enc)
+    return llvm::None;
+  // Construct the basic types.
+  auto context = type.getContext();
+  unsigned idxWidth = enc.getIndexBitWidth();
+  unsigned ptrWidth = enc.getPointerBitWidth();
+  RankedTensorType rType = type.cast<RankedTensorType>();
+  Type indexType = IndexType::get(context);
+  Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
+  Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
+  Type eltType = rType.getElementType();
+  //
+  // Sparse tensor storage for a rank-dimensional tensor is organized as a
+  // single compound type with the following fields:
+  //
+  // struct {
+  //   memref<rank x index> dimSizes    ; size in each dimension
+  //   ; per-dimension d:
+  //   ;  if dense:
+  //        <nothing>
+  //   ;  if compressed:
+  //        memref<? x idx> indices-d   ; indices for sparse dim d
+  //        memref<? x ptr> pointers-d  ; pointers for sparse dim d
+  //   ;  if singleton:
+  //        memref<? x idx> indices-d   ; indices for singleton dim d
+  //   memref<? x eltType> values       ; values
+  // };
+  //
+  // TODO: fill in the ? when statically known
+  //
+  // TODO: omit dimSizes when not needed (e.g. all-dense)
+  //
+  unsigned rank = rType.getShape().size();
+  SmallVector<Type, 8> fields;
+  fields.push_back(MemRefType::get({rank}, indexType));
+  for (unsigned r = 0; r < rank; r++) {
+    // Dimension level types apply in order to the reordered dimensions.
+    // As a result, the compound type can be constructed directly in the given
+    // order. Clients of this type know what field is what from the sparse
+    // tensor type.
+    switch (enc.getDimLevelType()[r]) {
+    case SparseTensorEncodingAttr::DimLevelType::Dense:
+      break;
+    case SparseTensorEncodingAttr::DimLevelType::Compressed:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
+      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
+      break;
+    case SparseTensorEncodingAttr::DimLevelType::Singleton:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
+      break;
+    }
   }
-  return llvm::None;
+  fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
+  // Sparse tensor storage (temporarily) lives in a tuple. This allows a
+  // simple 1:1 type conversion during codegen. A subsequent pass uses
+  // a 1:N type conversion to expand the tuple into its fields.
+  return TupleType::get(context, fields);
 }
 
 //===----------------------------------------------------------------------===//
@@ -67,7 +125,7 @@
 
 mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
   addConversion([](Type type) { return type; });
-  addConversion(convertSparseTensorTypes);
+  addConversion(convertSparseTensorType);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,14 +1,62 @@
 // RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{
-  dimLevelType = ["compressed"]
+  dimLevelType = [ "compressed" ],
+  indexBitWidth = 64,
+  pointerBitWidth = 32
 }>
 
-// TODO: just a dummy memref rewriting to get the ball rolling....
+#Dense = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "dense" ],
+  indexBitWidth = 64,
+  pointerBitWidth = 32
+}>
+
+#Row = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "dense" ],
+  indexBitWidth = 64,
+  pointerBitWidth = 32
+}>
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  indexBitWidth = 64,
+  pointerBitWidth = 32
+}>
+
+#DCSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  indexBitWidth = 64,
+  pointerBitWidth = 32
+}>
 
 // CHECK-LABEL: func @sparse_nop(
-// CHECK-SAME: %[[A:.*]]: memref<?xf64>) -> memref<?xf64> {
-// CHECK: return %[[A]] : memref<?xf64>
+// CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>) -> tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>
+// CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>
 func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
   return %arg0 : tensor<?xf64, #SparseVector>
 }
+
+// CHECK-LABEL: func @sparse_dense(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xf64>>)
+func.func @sparse_dense(%arg0: tensor<?x?xf64, #Dense>) {
+  return
+}
+
+// CHECK-LABEL: func @sparse_row(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>)
+func.func @sparse_row(%arg0: tensor<?x?xf64, #Row>) {
+  return
+}
+
+// CHECK-LABEL: func @sparse_csr(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>)
+func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
+  return
+}
+
+// CHECK-LABEL: func @sparse_dcsr(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xf64>>)
+func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
+  return
+}
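
Note (not part of the patch): the singleton cases in convertSparseTensorType are not exercised by the tests above. The following is a hedged sketch of what the conversion would produce for a COO-like 2-D encoding, assuming the encoding attribute accepts the "singleton" level type string; the #COO name, the ?x?xf64 tensor type, and @sparse_coo are illustrative only and follow the indices-before-pointers field order of the code as written.

#COO = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed", "singleton" ],
  indexBitWidth = 64,
  pointerBitWidth = 32
}>

// Expected result of the type conversion, per convertSparseTensorType:
//   tensor<?x?xf64, #COO>
//     -> tuple<memref<2xindex>,   // dimSizes
//              memref<?xi64>,     // indices for dim 0 (compressed)
//              memref<?xi32>,     // pointers for dim 0 (compressed)
//              memref<?xi64>,     // indices for dim 1 (singleton)
//              memref<?xf64>>     // values
func.func @sparse_coo(%arg0: tensor<?x?xf64, #COO>) {
  return
}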