diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -14,15 +14,68 @@
 include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
 include "mlir/IR/TensorEncoding.td"
 
-//===----------------------------------------------------------------------===//
-// Sparse Tensor Type Encoding Attribute
-//===----------------------------------------------------------------------===//
-
 // All of the Tensor attributes will extend this class.
 class SparseTensor_Attr<string name,
                         list<Trait> traits = []>
     : AttrDef<SparseTensor_Dialect, name, traits>;
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dimension Slice Attribute
+//===----------------------------------------------------------------------===//
+
+def SparseTensorDimSliceAttr : SparseTensor_Attr<"SparseTensorDimSlice", []> {
+  let mnemonic = "slice";
+
+  // TODO
+  let description = [{
+    abc
+  }];
+
+  let parameters = (
+    ins
+    "int64_t" : $offset,
+    "int64_t" : $size,
+    "int64_t" : $stride
+  );
+
+  let extraClassDeclaration = [{
+    /// Special value for dynamic offset/stride/size.
+    static constexpr int64_t kDynamic = -1;
+
+    static bool isDynamic(int64_t v) {
+      return v == kDynamic;
+    }
+
+    std::optional<uint64_t> getStaticOffset() const {
+      if (isDynamic(getOffset()))
+        return std::nullopt;
+      return static_cast<uint64_t>(getOffset());
+    };
+
+    std::optional<uint64_t> getStaticStride() const {
+      if (isDynamic(getStride()))
+        return std::nullopt;
+      return static_cast<uint64_t>(getStride());
+    }
+
+    std::optional<uint64_t> getStaticSize() const {
+      if (isDynamic(getSize()))
+        return std::nullopt;
+      return static_cast<uint64_t>(getSize());
+    }
+
+    bool isCompletelyDynamic() const {
+      return isDynamic(getOffset()) && isDynamic(getStride()) && isDynamic(getSize());
+    };
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Type Encoding Attribute
+//===----------------------------------------------------------------------===//
+
 // Sparse tensor encoding attribute.
 def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
          [ DeclareAttrInterfaceMethods<VerifiableTensorEncoding> ] > {
@@ -146,6 +199,7 @@
   }];
 
   // Data in sparse tensor encoding.
+
   let parameters = (
     ins
     // A dimension level type for each dimension of the tensor type.
@@ -160,9 +214,29 @@
     // The required bit width for pointer storage.
     "unsigned":$pointerBitWidth,
     // The required bit width for index storage.
-    "unsigned":$indexBitWidth
+    "unsigned":$indexBitWidth,
+    // A slice attribute for each dimension of the tensor type.
+    ArrayRefParameter<
+      "::mlir::sparse_tensor::SparseTensorDimSliceAttr",
+      "per dimension slice metadata"
+    >: $dimSlices
   );
 
+  let builders = [
+    AttrBuilder<(ins "ArrayRef<::mlir::sparse_tensor::DimLevelType>":$dimLevelType,
+                     "AffineMap":$dimOrdering,
+                     "AffineMap":$higherOrdering,
+                     "unsigned":$pointerBitWidth,
+                     "unsigned":$indexBitWidth), [{
+      return $_get($_ctxt, dimLevelType,
+                   dimOrdering,
+                   higherOrdering,
+                   pointerBitWidth,
+                   indexBitWidth,
+                   ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
+    }]>
+  ];
+
   let extraClassDeclaration = [{
    /// Returns the type for pointer storage based on pointerBitWidth
    Type getPointerType() const;
@@ -179,6 +253,17 @@
     /// Return true if the encoding has an identity dimension ordering.
     bool hasIdDimOrdering() const;
+
+    bool isSlice() const {
+      return !getDimSlices().empty();
+    }
+
+    std::optional<uint64_t> getStaticDimSliceOffset(unsigned dim) const;
+    std::optional<uint64_t> getStaticDimSliceSize(unsigned dim) const;
+    std::optional<uint64_t> getStaticDimSliceStride(unsigned dim) const;
+    std::optional<uint64_t> getStaticLvlSliceOffset(unsigned lvl) const;
+    std::optional<uint64_t> getStaticLvlSliceSize(unsigned lvl) const;
+    std::optional<uint64_t> getStaticLvlSliceStride(unsigned lvl) const;
   }];
 
   let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -45,6 +45,63 @@
   }
 }
 
+void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
+  printer << "(";
+  printer << (getStaticOffset() ? std::to_string(*getStaticOffset()) : "?");
+  printer << ", ";
+  printer << (getStaticSize() ? std::to_string(*getStaticSize()) : "?");
+  printer << ", ";
+  printer << (getStaticStride() ? std::to_string(*getStaticStride()) : "?");
+  printer << ")";
+}
+
+static ParseResult parseOptionalStaticSlice(int64_t &result,
+                                            AsmParser &parser) {
+  auto parseResult = parser.parseOptionalInteger(result);
+  if (parseResult.has_value()) {
+    if (parseResult.value().succeeded() && result < 0) {
+      parser.emitError(
+          parser.getCurrentLocation(),
+          "expect positive value for static slice offset/size/stride");
+      return failure();
+    }
+    return parseResult.value();
+  }
+
+  // Else, expect a '?', which represents a dynamic slice value.
+  result = SparseTensorDimSliceAttr::kDynamic;
+  return parser.parseQuestion();
+}
+
+Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
+  if (failed(parser.parseLParen()))
+    return {};
+
+  int64_t offset = -1;
+  if (failed(parseOptionalStaticSlice(offset, parser)))
+    return {};
+
+  if (parser.parseComma().failed())
+    return {};
+
+  int64_t size = -1;
+  if (failed(parseOptionalStaticSlice(size, parser)))
+    return {};
+
+  if (parser.parseComma().failed())
+    return {};
+
+  int64_t stride = -1;
+  if (failed(parseOptionalStaticSlice(stride, parser)))
+    return {};
+
+  if (failed(parser.parseRParen()))
+    return {};
+
+  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
+                                                     offset, size, stride);
+}
+
 Type SparseTensorEncodingAttr::getPointerType() const {
   unsigned ptrWidth = getPointerBitWidth();
   Type indexType = IndexType::get(getContext());
@@ -71,24 +128,70 @@
   return !getDimOrdering() || getDimOrdering().isIdentity();
 }
 
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticDimSliceOffset(unsigned dim) const {
+  return getDimSlices()[dim].getStaticOffset();
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticDimSliceSize(unsigned dim) const {
+  return getDimSlices()[dim].getStaticSize();
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticDimSliceStride(unsigned dim) const {
+  return getDimSlices()[dim].getStaticStride();
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticLvlSliceOffset(unsigned lvl) const {
+  return getStaticDimSliceOffset(toOrigDim(*this, lvl));
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticLvlSliceSize(unsigned lvl) const {
+  return getStaticDimSliceSize(toOrigDim(*this, lvl));
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticLvlSliceStride(unsigned lvl) const {
+  return getStaticDimSliceStride(toOrigDim(*this, lvl));
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
-  if (failed(parser.parseLess()))
-    return {};
-  // Parse the data as a dictionary.
-  DictionaryAttr dict;
-  if (failed(parser.parseAttribute(dict)))
-    return {};
-  if (failed(parser.parseGreater()))
-    return {};
+#define RETURN_ON_FAIL(stmt)  \
+  if (failed(stmt)) {         \
+    return {};                \
+  }
+
+  RETURN_ON_FAIL(parser.parseLess())
+  RETURN_ON_FAIL(parser.parseLBrace())
+
   // Process the data from the parsed dictionary value into struct-like data.
   SmallVector<DimLevelType> dlt;
+  SmallVector<SparseTensorDimSliceAttr> slices;
   AffineMap dimOrd = {};
   AffineMap higherOrd = {};
   unsigned ptr = 0;
   unsigned ind = 0;
-  for (const NamedAttribute &attr : dict) {
-    if (attr.getName() == "dimLevelType") {
-      auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
+
+  StringRef attrName;
+  // Exactly 6 keys.
+  SmallVector<StringRef, 6> keys = {"dimLevelType",   "dimOrdering",
+                                    "higherOrdering", "pointerBitWidth",
+                                    "indexBitWidth",  "slice"};
+  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
+    if (!llvm::is_contained(keys, attrName)) {
+      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
+      return {};
+    }
+
+    // Consume the `=` after the key.
+    RETURN_ON_FAIL(parser.parseEqual())
+    if (attrName == "dimLevelType") {
+      Attribute attr;
+      RETURN_ON_FAIL(parser.parseAttribute(attr));
+      auto arrayAttr = attr.dyn_cast<ArrayAttr>();
       if (!arrayAttr) {
         parser.emitError(parser.getNameLoc(),
                          "expected an array for dimension level types");
@@ -127,47 +230,80 @@
           return {};
         }
       }
-    } else if (attr.getName() == "dimOrdering") {
-      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
+    } else if (attrName == "dimOrdering") {
+      Attribute attr;
+      RETURN_ON_FAIL(parser.parseAttribute(attr))
+
+      auto affineAttr = attr.dyn_cast<AffineMapAttr>();
       if (!affineAttr) {
         parser.emitError(parser.getNameLoc(),
                          "expected an affine map for dimension ordering");
         return {};
       }
       dimOrd = affineAttr.getValue();
-    } else if (attr.getName() == "higherOrdering") {
-      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
+    } else if (attrName == "higherOrdering") {
+      Attribute attr;
+      RETURN_ON_FAIL(parser.parseAttribute(attr))
+
+      auto affineAttr = attr.dyn_cast<AffineMapAttr>();
       if (!affineAttr) {
         parser.emitError(parser.getNameLoc(),
                          "expected an affine map for higher ordering");
         return {};
       }
       higherOrd = affineAttr.getValue();
-    } else if (attr.getName() == "pointerBitWidth") {
-      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
+    } else if (attrName == "pointerBitWidth") {
+      Attribute attr;
+      RETURN_ON_FAIL(parser.parseAttribute(attr))
+
+      auto intAttr = attr.dyn_cast<IntegerAttr>();
       if (!intAttr) {
         parser.emitError(parser.getNameLoc(),
                          "expected an integral pointer bitwidth");
         return {};
       }
       ptr = intAttr.getInt();
-    } else if (attr.getName() == "indexBitWidth") {
-      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
+    } else if (attrName == "indexBitWidth") {
+      Attribute attr;
+      RETURN_ON_FAIL(parser.parseAttribute(attr))
+
+      auto intAttr = attr.dyn_cast<IntegerAttr>();
      if (!intAttr) {
         parser.emitError(parser.getNameLoc(),
                          "expected an integral index bitwidth");
         return {};
       }
       ind = intAttr.getInt();
-    } else {
-      parser.emitError(parser.getNameLoc(), "unexpected key: ")
-          << attr.getName().strref();
-      return {};
+    } else if (attrName == "slice") {
+      RETURN_ON_FAIL(parser.parseLSquare())
+      // Dispatch to SparseTensorDimSliceAttr::parse to skip the mnemonic.
+      bool finished = false;
+      while (auto attr = SparseTensorDimSliceAttr::parse(parser, nullptr)) {
+        auto sliceAttr = attr.cast<SparseTensorDimSliceAttr>();
+        slices.push_back(sliceAttr);
+        if (parser.parseOptionalComma().failed()) {
+          finished = true;
+          break;
+        }
+      }
+      // Something went wrong while parsing the slice list.
+      if (!finished)
+        return {};
+      RETURN_ON_FAIL(parser.parseRSquare())
     }
+
+    // Only the last item may omit the trailing comma.
+    if (parser.parseOptionalComma().failed())
+      break;
   }
+
+  RETURN_ON_FAIL(parser.parseRBrace())
+  RETURN_ON_FAIL(parser.parseGreater())
+#undef RETURN_ON_FAIL
+
   // Construct struct-like storage for attribute.
   return parser.getChecked<SparseTensorEncodingAttr>(
-      parser.getContext(), dlt, dimOrd, higherOrd, ptr, ind);
+      parser.getContext(), dlt, dimOrd, higherOrd, ptr, ind, slices);
 }
 
 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
@@ -188,14 +324,26 @@
     printer << ", pointerBitWidth = " << getPointerBitWidth();
   if (getIndexBitWidth())
     printer << ", indexBitWidth = " << getIndexBitWidth();
+  if (!getDimSlices().empty()) {
+    printer << ", slice = [ ";
+    llvm::interleaveComma(getDimSlices(), printer,
+                          [&](SparseTensorDimSliceAttr attr) {
+                            // Calls SparseTensorDimSliceAttr::print directly to
+                            // skip mnemonic.
+                            attr.print(printer);
+                          });
+    printer << " ]";
+  }
+
   printer << " }>";
 }
 
 LogicalResult SparseTensorEncodingAttr::verify(
     function_ref<InFlightDiagnostic()> emitError,
     ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
-    AffineMap higherOrdering, unsigned pointerBitWidth,
-    unsigned indexBitWidth) {
+    AffineMap higherOrdering, unsigned pointerBitWidth, unsigned indexBitWidth,
+    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
+  // TODO: verify slices
   if (!acceptBitWidth(pointerBitWidth))
     return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
   if (!acceptBitWidth(indexBitWidth))
@@ -226,7 +374,7 @@
   // Check structural integrity.
   if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
                     getHigherOrdering(), getPointerBitWidth(),
-                    getIndexBitWidth())))
+                    getIndexBitWidth(), getDimSlices())))
     return failure();
   // Check integrity with tensor type specifics. Dimension ordering is optional,
   // but we always should have dimension level types for the full rank.
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -68,3 +68,10 @@
 #a = #sparse_tensor.encoding<{dimLevelType = [ "compressed", "compressed", "dense", "dense" ], dimOrdering = affine_map<(ii, jj, i, j) -> (ii, jj, i, j)>, higherOrdering = affine_map<(i, j) -> (j, i)>}> // expected-error {{unexpected higher ordering mapping from 2 to 2}}
 func.func private @tensor_invalid_key(%arg0: tensor<10x60xf32, #a>) -> ()
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (-1, ?, 1), (?, 4, 2) ] // expected-error{{expect positive value for static slice offset/size/stride}}
+}>
+func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -88,3 +88,35 @@
 // CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], higherOrdering = affine_map<(d0, d1)[s0] -> (d0 * (s0 * 4), d0, d1)> }>>
 func.func private @sparse_ell(tensor<?x?xf64, #ELL>)
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+// CHECK-LABEL: func private @sparse_slice(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], slice = [ (1, 4, 1), (1, 4, 2) ] }>>
+func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
+
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+// CHECK-LABEL: func private @sparse_slice(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], slice = [ (1, 4, 1), (1, 4, 2) ] }>>
+func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
+
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, ?, 1), (?, 4, 2) ]
+}>
+
+// CHECK-LABEL: func private @sparse_slice(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], slice = [ (1, ?, 1), (?, 4, 2) ] }>>
+func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
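
A minimal usage sketch (not part of the patch) of how client code might query the new slice metadata through the accessors declared above. It assumes the existing mlir::sparse_tensor::getSparseTensorEncoding() helper from SparseTensor.h; the helper function name below is illustrative only.

// Hypothetical helper (illustrative only): reports the offset/size/stride of
// dimension `dim` of `tp` when all three slice components are static.
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

static bool getFullyStaticSlice(RankedTensorType tp, unsigned dim,
                                uint64_t &offset, uint64_t &size,
                                uint64_t &stride) {
  // getSparseTensorEncoding() returns a null attribute for dense tensors.
  SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
  if (!enc || !enc.isSlice())
    return false;
  std::optional<uint64_t> off = enc.getStaticDimSliceOffset(dim);
  std::optional<uint64_t> sz = enc.getStaticDimSliceSize(dim);
  std::optional<uint64_t> st = enc.getStaticDimSliceStride(dim);
  if (!off || !sz || !st)
    return false; // At least one component is dynamic ('?').
  offset = *off;
  size = *sz;
  stride = *st;
  return true;
}

For level (rather than dimension) queries, the getStaticLvlSlice*() accessors apply the dimension ordering via toOrigDim() before delegating to the per-dimension variants, as shown in the patch above.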