diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1432,10 +1432,12 @@ break; } os << "sparse<"; - printDenseIntOrFPElementsAttr(elementsAttr.getIndices(), - /*allowHex=*/false); - os << ", "; - printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true); + DenseIntElementsAttr indices = elementsAttr.getIndices(); + if (indices.getNumElements() != 0) { + printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false); + os << ", "; + printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true); + } os << '>'; break; } @@ -1476,20 +1478,15 @@ // Special case for degenerate tensors. auto numElements = type.getNumElements(); - int64_t rank = type.getRank(); - if (numElements == 0) { - for (int i = 0; i < rank; ++i) - os << '['; - for (int i = 0; i < rank; ++i) - os << ']'; + if (numElements == 0) return; - } // We use a mixed-radix counter to iterate through the shape. When we bump a // non-least-significant digit, we emit a close bracket. When we next emit an // element we re-open all closed brackets. // The mixed-radix counter, with radices in 'shape'. + int64_t rank = type.getRank(); SmallVector counter(rank, 0); // The number of brackets that have been opened and not closed. unsigned openBrackets = 0; diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -758,13 +758,13 @@ if (parseToken(Token::less, "expected '<' after 'dense'")) return nullptr; - // Parse the literal data. + // Parse the literal data if necessary. TensorLiteralParser literalParser(*this); - if (literalParser.parse(/*allowHex=*/true)) - return nullptr; - - if (parseToken(Token::greater, "expected '>'")) - return nullptr; + if (!consumeIf(Token::greater)) { + if (literalParser.parse(/*allowHex=*/true) || + parseToken(Token::greater, "expected '>'")) + return nullptr; + } auto typeLoc = getToken().getLoc(); auto type = parseElementsLiteralType(attrType); @@ -841,6 +841,25 @@ if (parseToken(Token::less, "Expected '<' after 'sparse'")) return nullptr; + // Check for the case where all elements are sparse. The indices are + // represented by a 2-dimensional shape where the second dimension is the rank + // of the type. + Type indiceEltType = builder.getIntegerType(64); + if (consumeIf(Token::greater)) { + ShapedType type = parseElementsLiteralType(attrType); + if (!type) + return nullptr; + + // Construct the sparse elements attr using zero element indice/value + // attributes. + ShapedType indicesType = + RankedTensorType::get({0, type.getRank()}, indiceEltType); + ShapedType valuesType = RankedTensorType::get({0}, type.getElementType()); + return SparseElementsAttr::get( + type, DenseElementsAttr::get(indicesType, ArrayRef()), + DenseElementsAttr::get(valuesType, ArrayRef())); + } + /// Parse the indices. We don't allow hex values here as we may need to use /// the inferred shape. auto indicesLoc = getToken().getLoc(); @@ -869,7 +888,6 @@ // 2-dimensional shape where the second dimension is the rank of the type. // Given that the parsed indices is a splat, we know that we only have one // indice and thus one for the first dimension. - auto indiceEltType = builder.getIntegerType(64); ShapedType indicesType; if (indiceParser.getShape().empty()) { indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -670,19 +670,21 @@ // CHECK: "fooi67"() {bar = dense<{{\[\[\[}}-5, 4, 6, 2]]]> : vector<1x1x4xi67>} : () -> () "fooi67"(){bar = dense<[[[-5, 4, 6, 2]]]> : vector<1x1x4xi67>} : () -> () -// CHECK: "foo2"() {bar = dense<[]> : tensor<0xi32>} : () -> () - "foo2"(){bar = dense<[]> : tensor<0xi32>} : () -> () -// CHECK: "foo2"() {bar = dense<{{\[\[}}]]> : tensor<1x0xi32>} : () -> () - "foo2"(){bar = dense<[[]]> : tensor<1x0xi32>} : () -> () +// CHECK: "foo2"() {bar = dense<> : tensor<0xi32>} : () -> () + "foo2"(){bar = dense<> : tensor<0xi32>} : () -> () +// CHECK: "foo2"() {bar = dense<> : tensor<1x0xi32>} : () -> () + "foo2"(){bar = dense<> : tensor<1x0xi32>} : () -> () +// CHECK: dense<> : tensor<0x512x512xi32> + "foo2"(){bar = dense<> : tensor<0x512x512xi32>} : () -> () // CHECK: "foo3"() {bar = dense<{{\[\[\[}}5, -6, 1, 2]], {{\[\[}}7, 8, 3, 4]]]> : tensor<2x1x4xi32>} : () -> () "foo3"(){bar = dense<[[[5, -6, 1, 2]], [[7, 8, 3, 4]]]> : tensor<2x1x4xi32>} : () -> () // CHECK: "float1"() {bar = dense<5.000000e+00> : tensor<1x1x1xf32>} : () -> () "float1"(){bar = dense<[[[5.0]]]> : tensor<1x1x1xf32>} : () -> () -// CHECK: "float2"() {bar = dense<[]> : tensor<0xf32>} : () -> () - "float2"(){bar = dense<[]> : tensor<0xf32>} : () -> () -// CHECK: "float2"() {bar = dense<{{\[\[}}]]> : tensor<1x0xf32>} : () -> () - "float2"(){bar = dense<[[]]> : tensor<1x0xf32>} : () -> () +// CHECK: "float2"() {bar = dense<> : tensor<0xf32>} : () -> () + "float2"(){bar = dense<> : tensor<0xf32>} : () -> () +// CHECK: "float2"() {bar = dense<> : tensor<1x0xf32>} : () -> () + "float2"(){bar = dense<> : tensor<1x0xf32>} : () -> () // CHECK: "bfloat16"() {bar = dense<{{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]> : tensor<2x1x4xbf16>} : () -> () "bfloat16"(){bar = dense<[[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]> : tensor<2x1x4xbf16>} : () -> () @@ -752,27 +754,27 @@ "fooi8"(){bar = sparse<0, -2> : tensor<1x1x1xi8>} : () -> () // CHECK: "fooi16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2, -1, 5]> : tensor<2x2x2xi16>} : () -> () "fooi16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2, -1, 5]> : tensor<2x2x2xi16>} : () -> () -// CHECK: "fooi32"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<1x1xi32>} : () -> () - "fooi32"(){bar = sparse<[], []> : tensor<1x1xi32>} : () -> () +// CHECK: "fooi32"() {bar = sparse<> : tensor<1x1xi32>} : () -> () + "fooi32"(){bar = sparse<> : tensor<1x1xi32>} : () -> () // CHECK: "fooi64"() {bar = sparse<0, -1> : tensor<1xi64>} : () -> () "fooi64"(){bar = sparse<[[0]], [-1]> : tensor<1xi64>} : () -> () -// CHECK: "foo2"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<0xi32>} : () -> () - "foo2"(){bar = sparse<[], []> : tensor<0xi32>} : () -> () -// CHECK: "foo3"() {bar = sparse<{{\[}}], {{\[}}]> : tensor} : () -> () - "foo3"(){bar = sparse<[], []> : tensor} : () -> () +// CHECK: "foo2"() {bar = sparse<> : tensor<0xi32>} : () -> () + "foo2"(){bar = sparse<> : tensor<0xi32>} : () -> () +// CHECK: "foo3"() {bar = sparse<> : tensor} : () -> () + "foo3"(){bar = sparse<> : tensor} : () -> () // CHECK: "foof16"() {bar = sparse<0, -2.000000e+00> : tensor<1x1x1xf16>} : () -> () "foof16"(){bar = sparse<0, -2.0> : tensor<1x1x1xf16>} : () -> () // CHECK: "foobf16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2.000000e+00, -1.000000e+00, 5.000000e+00]> : tensor<2x2x2xbf16>} : () -> () "foobf16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2.0, -1.0, 5.0]> : tensor<2x2x2xbf16>} : () -> () -// CHECK: "foof32"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<1x0x1xf32>} : () -> () - "foof32"(){bar = sparse<[], []> : tensor<1x0x1xf32>} : () -> () +// CHECK: "foof32"() {bar = sparse<> : tensor<1x0x1xf32>} : () -> () + "foof32"(){bar = sparse<> : tensor<1x0x1xf32>} : () -> () // CHECK: "foof64"() {bar = sparse<0, -1.000000e+00> : tensor<1xf64>} : () -> () "foof64"(){bar = sparse<[[0]], [-1.0]> : tensor<1xf64>} : () -> () -// CHECK: "foof320"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<0xf32>} : () -> () - "foof320"(){bar = sparse<[], []> : tensor<0xf32>} : () -> () -// CHECK: "foof321"() {bar = sparse<{{\[}}], {{\[}}]> : tensor} : () -> () - "foof321"(){bar = sparse<[], []> : tensor} : () -> () +// CHECK: "foof320"() {bar = sparse<> : tensor<0xf32>} : () -> () + "foof320"(){bar = sparse<> : tensor<0xf32>} : () -> () +// CHECK: "foof321"() {bar = sparse<> : tensor} : () -> () + "foof321"(){bar = sparse<> : tensor} : () -> () // CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> () "foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> () @@ -789,8 +791,8 @@ "fooi8"(){bar = sparse<0, -2> : vector<1x1x1xi8>} : () -> () // CHECK: "fooi16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2, -1, 5]> : vector<2x2x2xi16>} : () -> () "fooi16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2, -1, 5]> : vector<2x2x2xi16>} : () -> () -// CHECK: "fooi32"() {bar = sparse<{{\[}}], {{\[}}]> : vector<1x1xi32>} : () -> () - "fooi32"(){bar = sparse<[], []> : vector<1x1xi32>} : () -> () +// CHECK: "fooi32"() {bar = sparse<> : vector<1x1xi32>} : () -> () + "fooi32"(){bar = sparse<> : vector<1x1xi32>} : () -> () // CHECK: "fooi64"() {bar = sparse<0, -1> : vector<1xi64>} : () -> () "fooi64"(){bar = sparse<[[0]], [-1]> : vector<1xi64>} : () -> ()