diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -160,10 +160,33 @@ // The required bit width for pointer storage. "unsigned":$pointerBitWidth, // The required bit width for index storage. - "unsigned":$indexBitWidth + "unsigned":$indexBitWidth, + // Per-dimension slice metadata as (offset, length, stride) tuples; empty when the tensor is not a slice. + ArrayRefParameter< + "std::tuple", + "per dimension slice metadata" + >: $dimSlices ); + let builders = [ + AttrBuilder<(ins "ArrayRef<::mlir::sparse_tensor::DimLevelType>":$dimLevelType, + "AffineMap":$dimOrdering, + "AffineMap":$higherOrdering, + "unsigned":$pointerBitWidth, + "unsigned":$indexBitWidth), [{ + return $_get($_ctxt, dimLevelType, + dimOrdering, + higherOrdering, + pointerBitWidth, + indexBitWidth, + ArrayRef>{}); + }]> + ]; + let extraClassDeclaration = [{ + /// Slice metadata (offset, length, stride) of one dimension; -1 marks an unspecified value. + using DimSlice = std::tuple; + /// Returns the type for pointer storage based on pointerBitWidth Type getPointerType() const; @@ -179,6 +202,17 @@ /// Return true if the encoding has an identity dimension ordering.
bool hasIdDimOrdering() const; + + bool isSlice() const { + return !getDimSlices().empty(); + } + + std::optional getDimSliceOffset(unsigned dim) const; + std::optional getDimSliceLength(unsigned dim) const; + std::optional getDimSliceStride(unsigned dim) const; + std::optional getLvlSliceOffset(unsigned lvl) const; + std::optional getLvlSliceLength(unsigned lvl) const; + std::optional getLvlSliceStride(unsigned lvl) const; }]; let genVerifyDecl = 1; diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -71,6 +71,45 @@ return !getDimOrdering() || getDimOrdering().isIdentity(); } +std::optional +SparseTensorEncodingAttr::getDimSliceOffset(unsigned dim) const { + auto offset = std::get<0>(getDimSlices()[dim]); + if (offset != -1) // -1 marks an unspecified offset. + return offset; + return std::nullopt; +} + +std::optional +SparseTensorEncodingAttr::getDimSliceLength(unsigned dim) const { + auto length = std::get<1>(getDimSlices()[dim]); + if (length != -1) // -1 marks an unspecified length. + return length; + return std::nullopt; +} + +std::optional +SparseTensorEncodingAttr::getDimSliceStride(unsigned dim) const { + auto stride = std::get<2>(getDimSlices()[dim]); + if (stride != -1) // -1 marks an unspecified stride. + return stride; + return std::nullopt; +} + +std::optional +SparseTensorEncodingAttr::getLvlSliceOffset(unsigned lvl) const { + return getDimSliceOffset(toOrigDim(*this, lvl)); +} + +std::optional +SparseTensorEncodingAttr::getLvlSliceLength(unsigned lvl) const { + return getDimSliceLength(toOrigDim(*this, lvl)); +} + +std::optional +SparseTensorEncodingAttr::getLvlSliceStride(unsigned lvl) const { + return getDimSliceStride(toOrigDim(*this, lvl)); +} + Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { if (failed(parser.parseLess())) return {}; @@ -82,6 +121,7 @@ return {}; // Process the data from the parsed dictionary value into struct-like data.
SmallVector dlt; + SmallVector slices; AffineMap dimOrd = {}; AffineMap higherOrd = {}; unsigned ptr = 0; @@ -159,6 +199,31 @@ return {}; } ind = intAttr.getInt(); + } else if (attr.getName() == "slice") { + auto sliceList = attr.getValue().dyn_cast(); + if (!sliceList) { + parser.emitError(parser.getNameLoc(), "expected an array for slices"); + return {}; + } + for (auto entry : sliceList) { + auto triple = entry.dyn_cast(); + if (!triple || triple.size() != 3) { + parser.emitError(parser.getNameLoc(), + "expected an array of size 3 for a slice"); + return {}; + } + auto offset = triple.getValue()[0].dyn_cast(); + auto length = triple.getValue()[1].dyn_cast(); + auto stride = triple.getValue()[2].dyn_cast(); + // Each slice is an [offset, length, stride] triple of integers. + if (!offset || !length || !stride) { + parser.emitError(parser.getNameLoc(), + "expected an integral slice offset/length/stride"); + return {}; + } + slices.emplace_back(offset.getInt(), length.getInt(), + stride.getInt()); + } } else { parser.emitError(parser.getNameLoc(), "unexpected key: ") << attr.getName().strref(); @@ -167,7 +232,7 @@ } // Construct struct-like storage for attribute.
return parser.getChecked( - parser.getContext(), dlt, dimOrd, higherOrd, ptr, ind); + parser.getContext(), dlt, dimOrd, higherOrd, ptr, ind, slices); } void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { @@ -188,14 +253,26 @@ printer << ", pointerBitWidth = " << getPointerBitWidth(); if (getIndexBitWidth()) printer << ", indexBitWidth = " << getIndexBitWidth(); + if (!getDimSlices().empty()) { + printer << ", slice = [ "; + llvm::interleaveComma(getDimSlices(), printer, [&](DimSlice dimSlice) { + auto offset = std::get<0>(dimSlice); // `auto`: avoid narrowing to int + auto length = std::get<1>(dimSlice); + auto stride = std::get<2>(dimSlice); + printer << '[' << offset << ", " << length << ", " << stride << ']'; + }); + printer << " ]"; + } + printer << " }>"; } LogicalResult SparseTensorEncodingAttr::verify( function_ref emitError, ArrayRef dimLevelType, AffineMap dimOrdering, - AffineMap higherOrdering, unsigned pointerBitWidth, - unsigned indexBitWidth) { + AffineMap higherOrdering, unsigned pointerBitWidth, unsigned indexBitWidth, + ArrayRef dimSlices) { + // TODO: verify slices (one entry per dimension; valid offset/length/stride). if (!acceptBitWidth(pointerBitWidth)) return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; if (!acceptBitWidth(indexBitWidth)) @@ -226,7 +303,7 @@ // Check structural integrity. if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), getHigherOrdering(), getPointerBitWidth(), - getIndexBitWidth()))) + getIndexBitWidth(), getDimSlices()))) return failure(); // Check integrity with tensor type specifics. Dimension ordering is optional, // but we always should have dimension level types for the full rank.