diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -976,6 +976,11 @@
   /// Print a dense string elements attribute.
   void printDenseStringElementsAttr(DenseStringElementsAttr attr);
 
+  /// Print a dense int or FP elements attribute. If 'allowHex' is true, a hex
+  /// string is used instead of individual elements when the attr is large.
+  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
+                                     bool allowHex);
+
   void printDialectAttribute(Attribute attr);
   void printDialectType(Type type);
 
@@ -1396,13 +1401,13 @@
     break;
   }
   case StandardAttributes::DenseIntOrFPElements: {
-    auto eltsAttr = attr.cast<DenseElementsAttr>();
+    auto eltsAttr = attr.cast<DenseIntOrFPElementsAttr>();
     if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
       printElidedElementsAttr(os);
       break;
     }
     os << "dense<";
-    printDenseElementsAttr(eltsAttr, /*allowHex=*/true);
+    printDenseIntOrFPElementsAttr(eltsAttr, /*allowHex=*/true);
     os << '>';
     break;
   }
@@ -1425,7 +1430,8 @@
       break;
     }
     os << "sparse<";
-    printDenseElementsAttr(elementsAttr.getIndices(), /*allowHex=*/false);
+    printDenseIntOrFPElementsAttr(elementsAttr.getIndices(),
+                                  /*allowHex=*/false);
     os << ", ";
     printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
     os << '>';
@@ -1477,6 +1483,19 @@
 
 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
                                            bool allowHex) {
+  // String elements have a dedicated printer; 'allowHex' does not apply.
+  if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
+    printDenseStringElementsAttr(stringAttr);
+    return;
+  }
+
+  // Non-string dense elements are expected to be int/FP; 'cast' asserts it.
+  auto numAttr = attr.cast<DenseIntOrFPElementsAttr>();
+  printDenseIntOrFPElementsAttr(numAttr, allowHex);
+}
+
+void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
+                                                  bool allowHex) {
   auto type = attr.getType();
   auto shape = type.getShape();
   auto rank = type.getRank();
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -764,6 +764,11 @@
   "foof320"(){bar = sparse<[], []> : tensor<0xf32>} : () -> ()
 // CHECK: "foof321"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<f32>} : () -> ()
   "foof321"(){bar = sparse<[], []> : tensor<f32>} : () -> ()
+
+// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
+  "foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
+// CHECK: "foostr"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}"a", "b", "c"]> : tensor<2x2x2x!unknown<"">>} : () -> ()
+  "foostr"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], ["a", "b", "c"]> : tensor<2x2x2x!unknown<"">>} : () -> ()
   return
 }
 