[mlir] Verify SparseElementsAttr indices and values on construction

SparseElementsAttr now declares a verifier (genVerifyDecl) that checks that
the values are a 1-d tensor, that the indices/values shapes agree with the
attribute type, and that every sparse index lies within the attribute's
shape. ElementsAttr::isValidIndex and ElementsAttr::getFlattenedIndex gain
static overloads taking an explicit ShapedType so the verifier can use them
before an attribute exists, and getFlattenedIndex becomes public. The parser
now builds sparse attributes and memref types through a new
Parser::getChecked helper, so verification failures are reported at the
location of the `sparse` / `memref` keyword; the ad-hoc shape checks that
parseSparseElementsAttr used to perform are removed in favor of the verifier.

diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -70,6 +70,12 @@
 
   /// Return if the given 'index' refers to a valid element in this attribute.
   bool isValidIndex(ArrayRef<uint64_t> index) const;
+  static bool isValidIndex(ShapedType type, ArrayRef<uint64_t> index);
+
+  /// Returns the 1-dimensional flattened row-major index from the given
+  /// multi-dimensional index.
+  uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
+  static uint64_t getFlattenedIndex(ShapedType type, ArrayRef<uint64_t> index);
 
   /// Returns the number of elements held by this attribute.
   int64_t getNumElements() const;
@@ -94,11 +100,6 @@
 
   /// Method for support type inquiry through isa, cast and dyn_cast.
   static bool classof(Attribute attr);
-
-protected:
-  /// Returns the 1 dimensional flattened row-major index from the given
-  /// multi-dimensional index.
-  uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
 };
 
 namespace detail {
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -791,6 +791,7 @@
 
   public:
   }];
+  let genVerifyDecl = 1;
   let skipDefaultBuilders = 1;
 }
 
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -405,25 +405,45 @@
   return cast<SparseElementsAttr>().getValue(index);
 }
 
-/// Return if the given 'index' refers to a valid element in this attribute.
 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
-  auto type = getType();
-
+  return isValidIndex(getType(), index);
+}
+bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
   // Verify that the rank of the indices matches the held type.
-  auto rank = type.getRank();
+  int64_t rank = type.getRank();
   if (rank == 0 && index.size() == 1 && index[0] == 0)
     return true;
   if (rank != static_cast<int64_t>(index.size()))
     return false;
 
   // Verify that all of the indices are within the shape dimensions.
-  auto shape = type.getShape();
+  ArrayRef<int64_t> shape = type.getShape();
   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
     int64_t dim = static_cast<int64_t>(index[i]);
     return 0 <= dim && dim < shape[i];
   });
 }
 
+uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
+  return getFlattenedIndex(getType(), index);
+}
+uint64_t ElementsAttr::getFlattenedIndex(ShapedType type,
+                                         ArrayRef<uint64_t> index) {
+  assert(isValidIndex(type, index) && "expected valid multi-dimensional index");
+
+  // Reduce the provided multidimensional index into a flattened 1D row-major
+  // index.
+  int64_t rank = type.getRank();
+  ArrayRef<int64_t> shape = type.getShape();
+  uint64_t valueIndex = 0;
+  uint64_t dimMultiplier = 1;
+  for (int64_t i = rank - 1; i >= 0; --i) {
+    valueIndex += index[i] * dimMultiplier;
+    dimMultiplier *= shape[i];
+  }
+  return valueIndex;
+}
+
 ElementsAttr
 ElementsAttr::mapValues(Type newElementType,
                         function_ref<APInt(const APInt &)> mapping) const {
@@ -446,25 +466,6 @@
                   OpaqueElementsAttr, SparseElementsAttr>();
 }
 
-/// Returns the 1 dimensional flattened row-major index from the given
-/// multi-dimensional index.
-uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
-  assert(isValidIndex(index) && "expected valid multi-dimensional index");
-  auto type = getType();
-
-  // Reduce the provided multidimensional index into a flattended 1D row-major
-  // index.
-  auto rank = type.getRank();
-  auto shape = type.getShape();
-  uint64_t valueIndex = 0;
-  uint64_t dimMultiplier = 1;
-  for (int i = rank - 1; i >= 0; --i) {
-    valueIndex += index[i] * dimMultiplier;
-    dimMultiplier *= shape[i];
-  }
-  return valueIndex;
-}
-
 //===----------------------------------------------------------------------===//
 // DenseElementsAttr Utilities
 //===----------------------------------------------------------------------===//
@@ -1421,6 +1422,64 @@
   return flatSparseIndices;
 }
 
+LogicalResult
+SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                           ShapedType type, DenseIntElementsAttr sparseIndices,
+                           DenseElementsAttr values) {
+  ShapedType valuesType = values.getType();
+  if (valuesType.getRank() != 1)
+    return emitError() << "expected 1-d tensor for sparse element values";
+
+  // Verify the indices and values shape.
+  ShapedType indicesType = sparseIndices.getType();
+  auto emitShapeError = [&]() {
+    return emitError() << "expected shape ([" << type.getShape()
+                       << "]); inferred shape of indices literal (["
+                       << indicesType.getShape()
+                       << "]); inferred shape of values literal (["
+                       << valuesType.getShape() << "])";
+  };
+  // Verify indices shape.
+  size_t rank = type.getRank(), indicesRank = indicesType.getRank();
+  if (indicesRank == 2) {
+    if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
+      return emitShapeError();
+  } else if (indicesRank != 1 || rank != 1) {
+    return emitShapeError();
+  }
+  // Verify the values shape.
+  int64_t numSparseIndices = indicesType.getDimSize(0);
+  if (numSparseIndices != valuesType.getDimSize(0))
+    return emitShapeError();
+
+  // Verify that the sparse indices are within the value shape.
+  auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
+    return emitError()
+           << "sparse index #" << indexNum
+           << " is not contained within the value shape, with index=[" << index
+           << "], and type=" << type;
+  };
+
+  // Handle the case where the index values are a splat.
+  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
+  if (sparseIndices.isSplat()) {
+    SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
+    if (!ElementsAttr::isValidIndex(type, indices))
+      return emitIndexError(0, indices);
+    return success();
+  }
+
+  // Otherwise, reinterpret each index as an ArrayRef.
+  for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
+    ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
+                             rank);
+    if (!ElementsAttr::isValidIndex(type, index))
+      return emitIndexError(i, index);
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TypeAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -893,6 +893,7 @@
 
 /// Parse a sparse elements attribute.
 Attribute Parser::parseSparseElementsAttr(Type attrType) {
+  llvm::SMLoc loc = getToken().getLoc();
   consumeToken(Token::kw_sparse);
   if (parseToken(Token::less, "Expected '<' after 'sparse'"))
     return nullptr;
@@ -911,8 +912,8 @@
     ShapedType indicesType =
         RankedTensorType::get({0, type.getRank()}, indiceEltType);
     ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
-    return SparseElementsAttr::get(
-        type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
+    return getChecked<SparseElementsAttr>(
+        loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
         DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
   }
 
@@ -963,22 +964,6 @@
           : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
   auto values = valuesParser.getAttr(valuesLoc, valuesType);
 
-  /// Sanity check.
-  if (valuesType.getRank() != 1)
-    return (emitError("expected 1-d tensor for values"), nullptr);
-
-  auto sameShape = (indicesType.getRank() == 1) ||
-                   (type.getRank() == indicesType.getDimSize(1));
-  auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
-  if (!sameShape || !sameElementNum) {
-    emitError() << "expected shape ([" << type.getShape()
-                << "]); inferred shape of indices literal (["
-                << indicesType.getShape()
-                << "]); inferred shape of values literal (["
-                << valuesType.getShape() << "])";
-    return nullptr;
-  }
-
   // Build the sparse elements attribute by the indices and values.
-  return SparseElementsAttr::get(type, indices, values);
+  return getChecked<SparseElementsAttr>(loc, type, indices, values);
 }
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -140,6 +140,16 @@
   // Type Parsing
   //===--------------------------------------------------------------------===//
 
+  /// Invoke the `getChecked` method of the given Attribute or Type class, using
+  /// the provided location to emit errors in the case of failure. Note that
+  /// unlike `OpBuilder::getType`, this method does not implicitly insert a
+  /// context parameter.
+  template <typename T, typename... ParamsT>
+  T getChecked(llvm::SMLoc loc, ParamsT &&...params) {
+    return T::getChecked([&] { return emitError(loc); },
+                         std::forward<ParamsT>(params)...);
+  }
+
   ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
   ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
   ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -193,6 +193,7 @@
 ///   memory-space ::= integer-literal /* | TODO: address-space-id */
 ///
 Type Parser::parseMemRefType() {
+  llvm::SMLoc loc = getToken().getLoc();
   consumeToken(Token::kw_memref);
 
   if (parseToken(Token::less, "expected '<' in memref type"))
@@ -283,15 +284,11 @@
     }
   }
 
-  if (isUnranked) {
-    return UnrankedMemRefType::getChecked(
-        [&]() -> InFlightDiagnostic { return emitError(); }, elementType,
-        memorySpace);
-  }
+  if (isUnranked)
+    return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
 
-  return MemRefType::getChecked(
-      [&]() -> InFlightDiagnostic { return emitError(); }, dimensions,
-      elementType, affineMapComposition, memorySpace);
+  return getChecked<MemRefType>(loc, dimensions, elementType,
+                                affineMapComposition, memorySpace);
 }
 
 /// Parse any type except the function type.
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1087,19 +1087,19 @@
   // CHECK: 1.000000e+00 : f32
   // CHECK: 1.000000e+00 : f64
 
-  int64_t indices[] = {4, 7};
-  int64_t two = 2;
+  int64_t indices[] = {0, 1};
+  int64_t one = 1;
   MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
-      mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64), encoding),
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
       2, indices);
   MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
-      mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding), 2,
-      floats);
+      mlirRankedTensorTypeGet(1, &one, mlirF32TypeGet(ctx), encoding), 1,
+      floats);
   MlirAttribute sparseAttr = mlirSparseElementsAttribute(
       mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
       indicesAttr, valuesAttr);
   mlirAttributeDump(sparseAttr);
-  // CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32>
+  // CHECK: sparse<{{\[}}[0, 1]], 0.000000e+00> : tensor<1x2xf32>
 
   return 0;
 }
diff --git a/mlir/test/Dialect/Quant/convert-const.mlir b/mlir/test/Dialect/Quant/convert-const.mlir
--- a/mlir/test/Dialect/Quant/convert-const.mlir
+++ b/mlir/test/Dialect/Quant/convert-const.mlir
@@ -68,15 +68,15 @@
 // -----
 // Verifies i8 fixedpoint quantization on a sparse tensor, sweeping values.
 // CHECK-LABEL: const_sparse_tensor_i8_fixedpoint
-func @const_sparse_tensor_i8_fixedpoint() -> tensor<7x2xf32> {
+func @const_sparse_tensor_i8_fixedpoint() -> tensor<2x7xf32> {
   // NOTE: Ugly regex match pattern for opening "[[" of indices tensor.
-  // CHECK: %cst = constant sparse<{{\[}}[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]], [-128, -128, -64, 0, 64, 127, 127]> : tensor<7x2xi8>
+  // CHECK: %cst = constant sparse<{{\[}}[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]], [-128, -128, -64, 0, 64, 127, 127]> : tensor<2x7xi8>
   %cst = constant sparse<
       [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]],
-      [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7x2xf32>
-  %1 = "quant.qcast"(%cst) : (tensor<7x2xf32>) -> tensor<7x2x!quant.uniform<i8:f32, 7.812500e-03>>
-  %2 = "quant.dcast"(%1) : (tensor<7x2x!quant.uniform<i8:f32, 7.812500e-03>>) -> (tensor<7x2xf32>)
-  return %2 : tensor<7x2xf32>
+      [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<2x7xf32>
+  %1 = "quant.qcast"(%cst) : (tensor<2x7xf32>) -> tensor<2x7x!quant.uniform<i8:f32, 7.812500e-03>>
+  %2 = "quant.dcast"(%1) : (tensor<2x7x!quant.uniform<i8:f32, 7.812500e-03>>) -> (tensor<2x7xf32>)
+  return %2 : tensor<2x7xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -83,8 +83,8 @@
   %ext_2 = tensor.extract %1[%const_1, %const_1, %const_1] : tensor<4x4x4xf16>
 
   // Fold an extract into a sparse with a non sparse index.
-  %2 = constant sparse<[[1, 1, 1]], [-2.0]> : tensor<1x1x1xf16>
-  %ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<1x1x1xf16>
+  %2 = constant sparse<[[1, 1, 1]], [-2.0]> : tensor<2x2x2xf16>
+  %ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<2x2x2xf16>
 
   // Fold an extract into a dense tensor.
   %3 = constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -897,7 +897,7 @@
 // -----
 
 func @invalid_tensor_literal() {
-  // expected-error @+1 {{expected 1-d tensor for values}}
+  // expected-error @+1 {{expected 1-d tensor for sparse element values}}
   "foof16"(){bar = sparse<[[0, 0, 0]], [[-2.0]]> : vector<1x1x1xf16>} : () -> ()
 
 // -----
@@ -908,6 +908,12 @@
 
 // -----
 
+func @invalid_tensor_literal() {
+  // expected-error @+1 {{sparse index #0 is not contained within the value shape, with index=[1, 1], and type='tensor<1x1xi16>'}}
+  "fooi16"(){bar = sparse<1, 10> : tensor<1x1xi16>} : () -> ()
+
+// -----
+
 func @invalid_affine_structure() {
   %c0 = constant 0 : index
   %idx = affine.apply affine_map<(d0, d1)> (%c0, %c0) // expected-error {{expected '->' or ':'}}
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -810,7 +810,7 @@
   // CHECK: "fooi32"() {bar = sparse<> : tensor<1x1xi32>} : () -> ()
   "fooi32"(){bar = sparse<> : tensor<1x1xi32>} : () -> ()
   // CHECK: "fooi64"() {bar = sparse<0, -1> : tensor<1xi64>} : () -> ()
-  "fooi64"(){bar = sparse<[[0]], [-1]> : tensor<1xi64>} : () -> ()
+  "fooi64"(){bar = sparse<[0], [-1]> : tensor<1xi64>} : () -> ()
   // CHECK: "foo2"() {bar = sparse<> : tensor<0xi32>} : () -> ()
   "foo2"(){bar = sparse<> : tensor<0xi32>} : () -> ()
   // CHECK: "foo3"() {bar = sparse<> : tensor<i32>} : () -> ()
diff --git a/mlir/test/IR/pretty-attributes.mlir b/mlir/test/IR/pretty-attributes.mlir
--- a/mlir/test/IR/pretty-attributes.mlir
+++ b/mlir/test/IR/pretty-attributes.mlir
@@ -11,8 +11,8 @@
 // CHECK: dense<[1, 2]> : tensor<2xi32>
 "test.non_elided_dense_attr"() {foo.dense_attr = dense<[1, 2]> : tensor<2xi32>} : () -> ()
 
-// CHECK: opaque<"_", "0xDEADBEEF"> : vector<1x1x1xf16>
-"test.sparse_attr"() {foo.sparse_attr = sparse<[[1, 2, 3]], -2.0> : vector<1x1x1xf16>} : () -> ()
+// CHECK: opaque<"_", "0xDEADBEEF"> : vector<1x1x10xf16>
+"test.sparse_attr"() {foo.sparse_attr = sparse<[[0, 0, 5]], -2.0> : vector<1x1x10xf16>} : () -> ()
 
 // CHECK: opaque<"_", "0xDEADBEEF"> : tensor<100xf32>
 "test.opaque_attr"() {foo.opaque_attr = opaque<"_", "0xEBFE"> : tensor<100xf32> } : () -> ()
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1155,7 +1155,7 @@
 // CHECK-LABEL: @constants
 llvm.func @constants() -> vector<4xf32> {
   // CHECK: ret <4 x float> <float 4.200000e+01, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>
-  %0 = llvm.mlir.constant(sparse<[[0]], [4.2e+01]> : vector<4xf32>) : vector<4xf32>
+  %0 = llvm.mlir.constant(sparse<[0], [4.2e+01]> : vector<4xf32>) : vector<4xf32>
   llvm.return %0 : vector<4xf32>
 }
 