Summary (text before the first "diff --git" header is ignored by patch tools):
  * SparseTensorAttrDefs.td: the inline bodies of the SparseTensorDimSliceAttr
    helpers become declarations; isDynamic becomes constexpr; getStatic and
    getStaticString are declared as new static helpers.
  * SparseTensorDialect.cpp: acceptBitWidth becomes constexpr and moves to a new
    "Additional convenience methods" section; the helpers declared above are
    defined out of line; print() uses getStaticString; parse() initialises with
    kDynamic instead of -1; verify() reports a separate message for offset,
    size and stride.

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -76,32 +76,14 @@
   let extraClassDeclaration = [{
     /// Special value for dynamic offset/size/stride.
     static constexpr int64_t kDynamic = -1;
-
-    static bool isDynamic(int64_t v) {
-      return v == kDynamic;
-    }
-
-    std::optional<uint64_t> getStaticOffset() const {
-      if (isDynamic(getOffset()))
-        return std::nullopt;
-      return static_cast<uint64_t>(getOffset());
-    };
-
-    std::optional<uint64_t> getStaticStride() const {
-      if (isDynamic(getStride()))
-        return std::nullopt;
-      return static_cast<uint64_t>(getStride());
-    }
-
-    std::optional<uint64_t> getStaticSize() const {
-      if (isDynamic(getSize()))
-        return std::nullopt;
-      return static_cast<uint64_t>(getSize());
-    }
-
-    bool isCompletelyDynamic() const {
-      return isDynamic(getOffset()) && isDynamic(getStride()) && isDynamic(getSize());
-    };
+    static constexpr bool isDynamic(int64_t v) { return v == kDynamic; }
+    static std::optional<uint64_t> getStatic(int64_t v);
+    static std::string getStaticString(int64_t v);
+
+    std::optional<uint64_t> getStaticOffset() const;
+    std::optional<uint64_t> getStaticStride() const;
+    std::optional<uint64_t> getStaticSize() const;
+    bool isCompletelyDynamic() const;
   }];
 
   let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -31,6 +31,23 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+//===----------------------------------------------------------------------===//
+// Additional convenience methods.
+//===----------------------------------------------------------------------===//
+
+static constexpr bool acceptBitWidth(unsigned bitWidth) {
+  switch (bitWidth) {
+  case 0:
+  case 8:
+  case 16:
+  case 32:
+  case 64:
+    return true;
+  default:
+    return false;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // StorageLayout
 //===----------------------------------------------------------------------===//
@@ -166,26 +183,39 @@
 // TensorDialect Attribute Methods.
 //===----------------------------------------------------------------------===//
 
-static bool acceptBitWidth(unsigned bitWidth) {
-  switch (bitWidth) {
-  case 0:
-  case 8:
-  case 16:
-  case 32:
-  case 64:
-    return true;
-  default:
-    return false;
-  }
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
+  return isDynamic(v) ? std::nullopt
+                      : std::make_optional(static_cast<uint64_t>(v));
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
+  return getStatic(getOffset());
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
+  return getStatic(getStride());
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
+  return getStatic(getSize());
+}
+
+bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
+  return isDynamic(getOffset()) && isDynamic(getStride()) &&
+         isDynamic(getSize());
+}
+
+std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
+  return isDynamic(v) ? "?" : std::to_string(v);
 }
 
 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
   printer << "(";
-  printer << (getStaticOffset() ? std::to_string(*getStaticOffset()) : "?");
+  printer << getStaticString(getOffset());
   printer << ", ";
-  printer << (getStaticSize() ? std::to_string(*getStaticSize()) : "?");
+  printer << getStaticString(getSize());
   printer << ", ";
-  printer << (getStaticStride() ? std::to_string(*getStaticStride()) : "?");
+  printer << getStaticString(getStride());
   printer << ")";
 }
 
@@ -208,7 +238,7 @@
 }
 
 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
-  int64_t offset = -1, size = -1, stride = -1;
+  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
 
   if (failed(parser.parseLParen()) ||
       failed(parseOptionalStaticSlice(offset, parser)) ||
@@ -226,13 +256,13 @@
 LogicalResult
 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                  int64_t offset, int64_t size, int64_t stride) {
-  if ((offset == SparseTensorDimSliceAttr::kDynamic || offset >= 0) &&
-      (size == SparseTensorDimSliceAttr::kDynamic || size > 0) &&
-      (stride == SparseTensorDimSliceAttr::kDynamic || stride > 0)) {
-    return success();
-  }
-  return emitError()
-         << "expect positive value or ? for slice offset/size/stride";
+  if (!isDynamic(offset) && offset < 0)
+    return emitError() << "expect non-negative value or ? for slice offset";
+  if (!isDynamic(size) && size <= 0)
+    return emitError() << "expect positive value or ? for slice size";
+  if (!isDynamic(stride) && stride <= 0)
+    return emitError() << "expect positive value or ? for slice stride";
+  return success();
 }
 
 Type mlir::sparse_tensor::detail::getIntegerOrIndexType(MLIRContext *ctx,