diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -101,16 +101,18 @@ << "expect positive value or ? for slice offset/size/stride"; } +static Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth) { + if (bitwidth) + return IntegerType::get(ctx, bitwidth); + return IndexType::get(ctx); +} + Type SparseTensorEncodingAttr::getPointerType() const { - unsigned ptrWidth = getPointerBitWidth(); - Type indexType = IndexType::get(getContext()); - return ptrWidth ? IntegerType::get(getContext(), ptrWidth) : indexType; + return getIntegerOrIndexType(getContext(), getPointerBitWidth()); } Type SparseTensorEncodingAttr::getIndexType() const { - unsigned idxWidth = getIndexBitWidth(); - Type indexType = IndexType::get(getContext()); - return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType; + return getIntegerOrIndexType(getContext(), getIndexBitWidth()); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const { @@ -157,11 +159,30 @@ return getStaticDimSliceStride(toOrigDim(*this, lvl)); } +const static DimLevelType validDLTs[] = { + DimLevelType::Dense, DimLevelType::Compressed, + DimLevelType::CompressedNu, DimLevelType::CompressedNo, + DimLevelType::CompressedNuNo, DimLevelType::Singleton, + DimLevelType::SingletonNu, DimLevelType::SingletonNo, + DimLevelType::SingletonNuNo}; + +static std::optional parseDLT(StringRef str) { + for (DimLevelType dlt : validDLTs) + if (str == toMLIRString(dlt)) + return dlt; + return std::nullopt; +} + Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { #define RETURN_ON_FAIL(stmt) \ if (failed(stmt)) { \ return {}; \ } +#define ERROR_IF(COND, MSG) \ + if (COND) { \ + parser.emitError(parser.getNameLoc(), MSG); \ + return {}; \ + } RETURN_ON_FAIL(parser.parseLess()) RETURN_ON_FAIL(parser.parseLBrace()) @@ -191,37 +212,13 @@ Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)); auto arrayAttr = attr.dyn_cast(); - if (!arrayAttr) { - parser.emitError(parser.getNameLoc(), - "expected an array for dimension level types"); - return {}; - } + ERROR_IF(!arrayAttr, "expected an array for dimension level types") for (auto i : arrayAttr) { auto strAttr = i.dyn_cast(); - if (!strAttr) { - parser.emitError(parser.getNameLoc(), - "expected a string value in dimension level types"); - return {}; - } + ERROR_IF(!strAttr, "expected a string value in dimension level types") auto strVal = strAttr.getValue(); - if (strVal == "dense") { - dlt.push_back(DimLevelType::Dense); - } else if (strVal == "compressed") { - dlt.push_back(DimLevelType::Compressed); - } else if (strVal == "compressed-nu") { - dlt.push_back(DimLevelType::CompressedNu); - } else if (strVal == "compressed-no") { - dlt.push_back(DimLevelType::CompressedNo); - } else if (strVal == "compressed-nu-no") { - dlt.push_back(DimLevelType::CompressedNuNo); - } else if (strVal == "singleton") { - dlt.push_back(DimLevelType::Singleton); - } else if (strVal == "singleton-nu") { - dlt.push_back(DimLevelType::SingletonNu); - } else if (strVal == "singleton-no") { - dlt.push_back(DimLevelType::SingletonNo); - } else if (strVal == "singleton-nu-no") { - dlt.push_back(DimLevelType::SingletonNuNo); + if (auto optDLT = parseDLT(strVal)) { + dlt.push_back(optDLT.value()); } else { parser.emitError(parser.getNameLoc(), "unexpected dimension level type: ") @@ -232,46 +229,26 @@ } else if (attrName == "dimOrdering") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)) - auto affineAttr = attr.dyn_cast(); - if (!affineAttr) { - parser.emitError(parser.getNameLoc(), - "expected an affine map for dimension ordering"); - return {}; - } + ERROR_IF(!affineAttr, "expected an affine map for dimension ordering") dimOrd = affineAttr.getValue(); } else if (attrName == "higherOrdering") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)) - auto affineAttr = attr.dyn_cast(); - if (!affineAttr) { - parser.emitError(parser.getNameLoc(), - "expected an affine map for higher ordering"); - return {}; - } + ERROR_IF(!affineAttr, "expected an affine map for higher ordering") higherOrd = affineAttr.getValue(); } else if (attrName == "pointerBitWidth") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)) - auto intAttr = attr.dyn_cast(); - if (!intAttr) { - parser.emitError(parser.getNameLoc(), - "expected an integral pointer bitwidth"); - return {}; - } + ERROR_IF(!intAttr, "expected an integral pointer bitwidth") ptr = intAttr.getInt(); } else if (attrName == "indexBitWidth") { Attribute attr; RETURN_ON_FAIL(parser.parseAttribute(attr)) - auto intAttr = attr.dyn_cast(); - if (!intAttr) { - parser.emitError(parser.getNameLoc(), - "expected an integral index bitwidth"); - return {}; - } + ERROR_IF(!intAttr, "expected an integral index bitwidth") ind = intAttr.getInt(); } else if (attrName == "slice") { RETURN_ON_FAIL(parser.parseLSquare()) @@ -298,6 +275,7 @@ RETURN_ON_FAIL(parser.parseRBrace()) RETURN_ON_FAIL(parser.parseGreater()) +#undef ERROR_IF #undef RETURN_ON_FAIL // Construct struct-like storage for attribute. @@ -367,18 +345,21 @@ return emitError() << "unexpected mismatch in dimension slices and " "dimension level type size"; } - return success(); } +#define RETURN_FAILURE_IF_FAILED(X) \ + if (failed(X)) { \ + return failure(); \ + } + LogicalResult SparseTensorEncodingAttr::verifyEncoding( ArrayRef shape, Type elementType, function_ref emitError) const { // Check structural integrity. - if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), - getHigherOrdering(), getPointerBitWidth(), - getIndexBitWidth(), getDimSlices()))) - return failure(); + RETURN_FAILURE_IF_FAILED(verify( + emitError, getDimLevelType(), getDimOrdering(), getHigherOrdering(), + getPointerBitWidth(), getIndexBitWidth(), getDimSlices())) // Check integrity with tensor type specifics. Dimension ordering is optional, // but we always should have dimension level types for the full rank. unsigned size = shape.size(); @@ -435,23 +416,17 @@ bool mlir::sparse_tensor::isUniqueCOOType(RankedTensorType tp) { SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp); - if (!enc) - return false; - - return isCOOType(enc, 0, /*isUnique=*/true); + return enc && isCOOType(enc, 0, /*isUnique=*/true); } unsigned mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) { - unsigned rank = enc.getDimLevelType().size(); - if (rank <= 1) - return rank; - + const unsigned rank = enc.getDimLevelType().size(); // We only consider COO region with at least two dimensions for the purpose // of AOS storage optimization. - for (unsigned r = 0; r < rank - 1; r++) { - if (isCOOType(enc, r, /*isUnique=*/false)) - return r; - } + if (rank > 1) + for (unsigned r = 0; r < rank - 1; r++) + if (isCOOType(enc, r, /*isUnique=*/false)) + return r; return rank; } @@ -541,10 +516,8 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind, std::optional dim) const { - std::optional intDim; - if (dim) - intDim = dim.value().getZExtValue(); - return getFieldType(kind, intDim); + return getFieldType(kind, dim ? std::optional(dim.value().getZExtValue()) + : std::nullopt); } //===----------------------------------------------------------------------===// @@ -552,17 +525,12 @@ //===----------------------------------------------------------------------===// static LogicalResult isInBounds(uint64_t dim, Value tensor) { - uint64_t rank = tensor.getType().cast().getRank(); - if (dim >= rank) - return failure(); - return success(); // in bounds + return success(dim < tensor.getType().cast().getRank()); } static LogicalResult isMatchingWidth(Value result, unsigned width) { - Type etp = result.getType().cast().getElementType(); - if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) - return success(); - return failure(); + const Type etp = result.getType().cast().getElementType(); + return success(width == 0 ? etp.isIndex() : etp.isInteger(width)); } static LogicalResult verifySparsifierGetterSetter( @@ -663,11 +631,8 @@ } LogicalResult GetStorageSpecifierOp::verify() { - if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(), - getSpecifier(), getOperation()))) { - return failure(); - } - + RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter( + getSpecifierKind(), getDim(), getSpecifier(), getOperation())) // Checks the result type if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) != getResult().getType()) { @@ -692,11 +657,8 @@ } LogicalResult SetStorageSpecifierOp::verify() { - if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(), - getSpecifier(), getOperation()))) { - return failure(); - } - + RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter( + getSpecifierKind(), getDim(), getSpecifier(), getOperation())) // Checks the input type if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) != getValue().getType()) { @@ -748,59 +710,45 @@ // Check correct number of block arguments and return type for each // non-empty region. - LogicalResult regionResult = success(); if (!overlap.empty()) { - regionResult = verifyNumBlockArgs( - this, overlap, "overlap", TypeRange{leftType, rightType}, outputType); - if (failed(regionResult)) - return regionResult; + RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs( + this, overlap, "overlap", TypeRange{leftType, rightType}, outputType)) } if (!left.empty()) { - regionResult = - verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType); - if (failed(regionResult)) - return regionResult; + RETURN_FAILURE_IF_FAILED( + verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType)) } else if (getLeftIdentity()) { if (leftType != outputType) return emitError("left=identity requires first argument to have the same " "type as the output"); } if (!right.empty()) { - regionResult = verifyNumBlockArgs(this, right, "right", - TypeRange{rightType}, outputType); - if (failed(regionResult)) - return regionResult; + RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs( + this, right, "right", TypeRange{rightType}, outputType)) } else if (getRightIdentity()) { if (rightType != outputType) return emitError("right=identity requires second argument to have the " "same type as the output"); } - return success(); } LogicalResult UnaryOp::verify() { Type inputType = getX().getType(); Type outputType = getOutput().getType(); - LogicalResult regionResult = success(); // Check correct number of block arguments and return type for each // non-empty region. Region &present = getPresentRegion(); if (!present.empty()) { - regionResult = verifyNumBlockArgs(this, present, "present", - TypeRange{inputType}, outputType); - if (failed(regionResult)) - return regionResult; + RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs( + this, present, "present", TypeRange{inputType}, outputType)) } Region &absent = getAbsentRegion(); if (!absent.empty()) { - regionResult = - verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType); - if (failed(regionResult)) - return regionResult; + RETURN_FAILURE_IF_FAILED( + verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType)) } - return success(); } @@ -880,8 +828,7 @@ } LogicalResult PushBackOp::verify() { - Value n = getN(); - if (n) { + if (Value n = getN()) { auto nValue = dyn_cast_or_null(n.getDefiningOp()); if (nValue && nValue.value() < 1) return emitOpError("n must be not less than 1"); @@ -972,32 +919,21 @@ LogicalResult ReduceOp::verify() { Type inputType = getX().getType(); - LogicalResult regionResult = success(); - // Check correct number of block arguments and return type. Region &formula = getRegion(); - regionResult = verifyNumBlockArgs(this, formula, "reduce", - TypeRange{inputType, inputType}, inputType); - if (failed(regionResult)) - return regionResult; - + RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs( + this, formula, "reduce", TypeRange{inputType, inputType}, inputType)) return success(); } LogicalResult SelectOp::verify() { Builder b(getContext()); - Type inputType = getX().getType(); Type boolType = b.getI1Type(); - LogicalResult regionResult = success(); - // Check correct number of block arguments and return type. Region &formula = getRegion(); - regionResult = verifyNumBlockArgs(this, formula, "select", - TypeRange{inputType}, boolType); - if (failed(regionResult)) - return regionResult; - + RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(this, formula, "select", + TypeRange{inputType}, boolType)) return success(); } @@ -1025,15 +961,8 @@ } return success(); }; - - LogicalResult result = checkTypes(getXs()); - if (failed(result)) - return result; - - if (n) - return checkTypes(getYs(), false); - - return success(); + RETURN_FAILURE_IF_FAILED(checkTypes(getXs())) + return n ? checkTypes(getYs(), false) : success(); } LogicalResult SortCooOp::verify() { @@ -1084,6 +1013,8 @@ "reduce, select or foreach"); } +#undef RETURN_FAILURE_IF_FAILED + //===----------------------------------------------------------------------===// // TensorDialect Methods. //===----------------------------------------------------------------------===//