diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -14,15 +14,69 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td" include "mlir/IR/TensorEncoding.td" -//===----------------------------------------------------------------------===// -// Sparse Tensor Type Encoding Attribute -//===----------------------------------------------------------------------===// - // All of the Tensor attributes will extend this class. class SparseTensor_Attr traits = []> : AttrDef; +//===----------------------------------------------------------------------===// +// Sparse Tensor Dimension Slice Attribute. +//===----------------------------------------------------------------------===// + +def SparseTensorDimSliceAttr : SparseTensor_Attr<"SparseTensorDimSlice", []> { + let mnemonic = "slice"; + + let description = [{ + An attribute to encode slice information of a sparse tensor on a particular + dimension (a tuple of offset, size, stride). + }]; + + let parameters = ( + ins + "int64_t" : $offset, + "int64_t" : $size, + "int64_t" : $stride + ); + + let extraClassDeclaration = [{ + /// Special value for dynamic offset/size/stride.
+ static constexpr int64_t kDynamic = -1; + + static bool isDynamic(int64_t v) { + return v == kDynamic; + } + + std::optional getStaticOffset() const { + if (isDynamic(getOffset())) + return std::nullopt; + return static_cast(getOffset()); + } + + std::optional getStaticStride() const { + if (isDynamic(getStride())) + return std::nullopt; + return static_cast(getStride()); + } + + std::optional getStaticSize() const { + if (isDynamic(getSize())) + return std::nullopt; + return static_cast(getSize()); + } + + bool isCompletelyDynamic() const { + return isDynamic(getOffset()) && isDynamic(getStride()) && isDynamic(getSize()); + } + }]; + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// Sparse Tensor Type Encoding Attribute. +//===----------------------------------------------------------------------===// + // Sparse tensor encoding attribute. def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding", [ DeclareAttrInterfaceMethods ] > { @@ -103,6 +157,9 @@ choices are `8`, `16`, `32`, `64`, or, the default, `0` to indicate a native bit width. + - An optional array of SparseTensorDimSliceAttr, which specifies how the sparse + tensor is sliced on each dimension. + Examples: ```mlir @@ -142,6 +199,15 @@ higherOrdering = affine_map<(i, j)[c] -> (c * 4 * i, i, j)> }> ... tensor ... + + // CSR slice (offset = 0, size = 4, stride = 1 on the first dimension; + // offset = 0, size = 8, and a dynamic stride on the second dimension). + #CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (0, 4, 1), (0, 8, ?) ] + }> + ... tensor ... + + ``` }]; @@ -160,9 +226,29 @@ // The required bit width for pointer storage. "unsigned":$pointerBitWidth, // The required bit width for index storage. - "unsigned":$indexBitWidth + "unsigned":$indexBitWidth, + // A slice attribute for each dimension of the tensor type.
+ ArrayRefParameter< + "::mlir::sparse_tensor::SparseTensorDimSliceAttr", + "per dimension slice metadata" + >: $dimSlices ); + let builders = [ + AttrBuilder<(ins "ArrayRef<::mlir::sparse_tensor::DimLevelType>":$dimLevelType, + "AffineMap":$dimOrdering, + "AffineMap":$higherOrdering, + "unsigned":$pointerBitWidth, + "unsigned":$indexBitWidth), [{ + return $_get($_ctxt, dimLevelType, + dimOrdering, + higherOrdering, + pointerBitWidth, + indexBitWidth, + ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{}); + }]> + ]; + let extraClassDeclaration = [{ /// Returns the type for pointer storage based on pointerBitWidth Type getPointerType() const; @@ -179,6 +265,17 @@ /// Return true if the encoding has an identity dimension ordering. bool hasIdDimOrdering() const; + + bool isSlice() const { + return !getDimSlices().empty(); + } + + std::optional getStaticDimSliceOffset(unsigned dim) const; + std::optional getStaticDimSliceSize(unsigned dim) const; + std::optional getStaticDimSliceStride(unsigned dim) const; + std::optional getStaticLvlSliceOffset(unsigned lvl) const; + std::optional getStaticLvlSliceSize(unsigned lvl) const; + std::optional getStaticLvlSliceStride(unsigned lvl) const; }]; let genVerifyDecl = 1; @@ -186,7 +283,7 @@ } //===----------------------------------------------------------------------===// -// Sparse Tensor Storage Specifier Enum Attribute +// Sparse Tensor Storage Specifier Enum Attribute. //===----------------------------------------------------------------------===// // The C++ enum for Storage Specifier kind. @@ -209,7 +306,7 @@ } //===----------------------------------------------------------------------===// -// Sparse Tensor Traits +// Sparse Tensor Traits.
//===----------------------------------------------------------------------===// def IsSparseTensorPred diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -45,6 +45,62 @@ } } +void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { + printer << "("; + printer << (getStaticOffset() ? std::to_string(*getStaticOffset()) : "?"); + printer << ", "; + printer << (getStaticSize() ? std::to_string(*getStaticSize()) : "?"); + printer << ", "; + printer << (getStaticStride() ? std::to_string(*getStaticStride()) : "?"); + printer << ")"; +} + +static ParseResult parseOptionalStaticSlice(int64_t &result, + AsmParser &parser) { + auto parseResult = parser.parseOptionalInteger(result); + if (parseResult.has_value()) { + if (parseResult.value().succeeded() && result < 0) { + parser.emitError( + parser.getCurrentLocation(), + "expect positive value or ? for slice offset/size/stride"); + return failure(); + } + return parseResult.value(); + } + + // Otherwise, expect a '?',
which represents a dynamic value. + result = SparseTensorDimSliceAttr::kDynamic; + return parser.parseQuestion(); +} + +Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { + int64_t offset = kDynamic, size = kDynamic, stride = kDynamic; + + if (failed(parser.parseLParen()) || + failed(parseOptionalStaticSlice(offset, parser)) || + failed(parser.parseComma()) || + failed(parseOptionalStaticSlice(size, parser)) || + failed(parser.parseComma()) || + failed(parseOptionalStaticSlice(stride, parser)) || + failed(parser.parseRParen())) + return {}; + + return parser.getChecked(parser.getContext(), + offset, size, stride); +} + +LogicalResult +SparseTensorDimSliceAttr::verify(function_ref emitError, + int64_t offset, int64_t size, int64_t stride) { + if ((offset == SparseTensorDimSliceAttr::kDynamic || offset >= 0) && + (size == SparseTensorDimSliceAttr::kDynamic || size > 0) && + (stride == SparseTensorDimSliceAttr::kDynamic || stride > 0)) { + return success(); + } + return emitError() + << "expect positive value or ?
for slice offset/size/stride"; +} + Type SparseTensorEncodingAttr::getPointerType() const { unsigned ptrWidth = getPointerBitWidth(); Type indexType = IndexType::get(getContext()); @@ -71,24 +127,70 @@ return !getDimOrdering() || getDimOrdering().isIdentity(); } +std::optional +SparseTensorEncodingAttr::getStaticDimSliceOffset(unsigned dim) const { + return getDimSlices()[dim].getStaticOffset(); +} + +std::optional +SparseTensorEncodingAttr::getStaticDimSliceSize(unsigned dim) const { + return getDimSlices()[dim].getStaticSize(); +} + +std::optional +SparseTensorEncodingAttr::getStaticDimSliceStride(unsigned dim) const { + return getDimSlices()[dim].getStaticStride(); +} + +std::optional +SparseTensorEncodingAttr::getStaticLvlSliceOffset(unsigned lvl) const { + return getStaticDimSliceOffset(toOrigDim(*this, lvl)); +} + +std::optional +SparseTensorEncodingAttr::getStaticLvlSliceSize(unsigned lvl) const { + return getStaticDimSliceSize(toOrigDim(*this, lvl)); +} + +std::optional +SparseTensorEncodingAttr::getStaticLvlSliceStride(unsigned lvl) const { + return getStaticDimSliceStride(toOrigDim(*this, lvl)); +} + Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { - if (failed(parser.parseLess())) - return {}; - // Parse the data as a dictionary. - DictionaryAttr dict; - if (failed(parser.parseAttribute(dict))) - return {}; - if (failed(parser.parseGreater())) - return {}; +#define RETURN_ON_FAIL(stmt) \ + if (failed(stmt)) { \ + return {}; \ + } + + RETURN_ON_FAIL(parser.parseLess()) + RETURN_ON_FAIL(parser.parseLBrace()) + // Process the data from the parsed dictionary value into struct-like data. SmallVector dlt; + SmallVector slices; AffineMap dimOrd = {}; AffineMap higherOrd = {}; unsigned ptr = 0; unsigned ind = 0; - for (const NamedAttribute &attr : dict) { - if (attr.getName() == "dimLevelType") { - auto arrayAttr = attr.getValue().dyn_cast(); + + StringRef attrName; + // Exactly 6 keys.
+ SmallVector keys = {"dimLevelType", "dimOrdering", + "higherOrdering", "pointerBitWidth", + "indexBitWidth", "slice"}; + while (succeeded(parser.parseOptionalKeyword(&attrName))) { + if (!llvm::is_contained(keys, attrName)) { + parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; + return {}; + } + + // Consume the `=` that follows the key. + RETURN_ON_FAIL(parser.parseEqual()) + if (attrName == "dimLevelType") { + Attribute attr; + RETURN_ON_FAIL(parser.parseAttribute(attr)) + auto arrayAttr = attr.dyn_cast(); if (!arrayAttr) { parser.emitError(parser.getNameLoc(), "expected an array for dimension level types"); @@ -127,47 +229,80 @@ return {}; } } - } else if (attr.getName() == "dimOrdering") { - auto affineAttr = attr.getValue().dyn_cast(); + } else if (attrName == "dimOrdering") { + Attribute attr; + RETURN_ON_FAIL(parser.parseAttribute(attr)) + + auto affineAttr = attr.dyn_cast(); if (!affineAttr) { parser.emitError(parser.getNameLoc(), "expected an affine map for dimension ordering"); return {}; } dimOrd = affineAttr.getValue(); - } else if (attr.getName() == "higherOrdering") { - auto affineAttr = attr.getValue().dyn_cast(); + } else if (attrName == "higherOrdering") { + Attribute attr; + RETURN_ON_FAIL(parser.parseAttribute(attr)) + + auto affineAttr = attr.dyn_cast(); if (!affineAttr) { parser.emitError(parser.getNameLoc(), "expected an affine map for higher ordering"); return {}; } higherOrd = affineAttr.getValue(); - } else if (attr.getName() == "pointerBitWidth") { - auto intAttr = attr.getValue().dyn_cast(); + } else if (attrName == "pointerBitWidth") { + Attribute attr; + RETURN_ON_FAIL(parser.parseAttribute(attr)) + + auto intAttr = attr.dyn_cast(); if (!intAttr) { parser.emitError(parser.getNameLoc(), "expected an integral pointer bitwidth"); return {}; } ptr = intAttr.getInt(); - } else if (attr.getName() == "indexBitWidth") { - auto intAttr = attr.getValue().dyn_cast(); + } else if (attrName == "indexBitWidth") { + Attribute attr; +
RETURN_ON_FAIL(parser.parseAttribute(attr)) + + auto intAttr = attr.dyn_cast(); if (!intAttr) { parser.emitError(parser.getNameLoc(), "expected an integral index bitwidth"); return {}; } ind = intAttr.getInt(); - } else { - parser.emitError(parser.getNameLoc(), "unexpected key: ") - << attr.getName().strref(); - return {}; + } else if (attrName == "slice") { + RETURN_ON_FAIL(parser.parseLSquare()) + // Call SparseTensorDimSliceAttr::parse directly to skip the mnemonic. + bool finished = false; + while (auto attr = SparseTensorDimSliceAttr::parse(parser, nullptr)) { + auto sliceAttr = attr.cast(); + slices.push_back(sliceAttr); + if (parser.parseOptionalComma().failed()) { + finished = true; + break; + } + } + // Bail out if a slice failed to parse (the error was already emitted). + if (!finished) + return {}; + RETURN_ON_FAIL(parser.parseRSquare()) } + + // Only the last item can omit the comma + if (parser.parseOptionalComma().failed()) + break; } + + RETURN_ON_FAIL(parser.parseRBrace()) + RETURN_ON_FAIL(parser.parseGreater()) +#undef RETURN_ON_FAIL + // Construct struct-like storage for attribute. return parser.getChecked( - parser.getContext(), dlt, dimOrd, higherOrd, ptr, ind); + parser.getContext(), dlt, dimOrd, higherOrd, ptr, ind, slices); } void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { @@ -188,14 +323,26 @@ printer << ", pointerBitWidth = " << getPointerBitWidth(); if (getIndexBitWidth()) printer << ", indexBitWidth = " << getIndexBitWidth(); + if (isSlice()) { + printer << ", slice = [ "; + llvm::interleaveComma(getDimSlices(), printer, + [&](SparseTensorDimSliceAttr attr) { + // Calls SparseTensorDimSliceAttr::print directly to + // skip mnemonic.
+ attr.print(printer); + }); + printer << " ]"; + } + printer << " }>"; } LogicalResult SparseTensorEncodingAttr::verify( function_ref emitError, ArrayRef dimLevelType, AffineMap dimOrdering, - AffineMap higherOrdering, unsigned pointerBitWidth, - unsigned indexBitWidth) { + AffineMap higherOrdering, unsigned pointerBitWidth, unsigned indexBitWidth, + ArrayRef dimSlices) { + // TODO: verify slices if (!acceptBitWidth(pointerBitWidth)) return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; if (!acceptBitWidth(indexBitWidth)) @@ -217,6 +364,11 @@ return emitError() << "unexpected mismatch in higher ordering and " "dimension level types size"; } + if (!dimSlices.empty() && dimSlices.size() != dimLevelType.size()) { + return emitError() << "unexpected mismatch in dimension slices and " + "dimension level type size"; + } + return success(); } @@ -226,7 +378,7 @@ // Check structural integrity. if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), getHigherOrdering(), getPointerBitWidth(), - getIndexBitWidth()))) + getIndexBitWidth(), getDimSlices()))) return failure(); // Check integrity with tensor type specifics. Dimension ordering is optional, // but we always should have dimension level types for the full rank.
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir --- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir @@ -68,3 +68,10 @@ #a = #sparse_tensor.encoding<{dimLevelType = [ "compressed", "compressed", "dense", "dense" ], dimOrdering = affine_map<(ii, jj, i, j) -> (ii, jj, i, j)>, higherOrdering = affine_map<(i, j) -> (j, i)>}> // expected-error {{unexpected higher ordering mapping from 2 to 2}} func.func private @tensor_invalid_key(%arg0: tensor<10x60xf32, #a>) -> () +// ----- + +#CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (-1, ?, 1), (?, 4, 2) ] // expected-error{{expect positive value or ? for slice offset/size/stride}} +}> +func.func private @sparse_slice(tensor) diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir @@ -88,3 +88,24 @@ // CHECK-SAME: tensor (d0 * (s0 * 4), d0, d1)> }>> func.func private @sparse_ell(tensor) +// ----- + +#CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (1, 4, 1), (1, 4, 2) ] +}> + +// CHECK-LABEL: func private @sparse_slice( +// CHECK-SAME: tensor> +func.func private @sparse_slice(tensor) + +// ----- + +#CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (1, ?, 1), (?, 4, 2) ] +}> + +// CHECK-LABEL: func private @sparse_slice( +// CHECK-SAME: tensor> +func.func private @sparse_slice(tensor)