"""GDB pretty printers for MLIR types.

The module builds a ``gdb.printing.RegexpCollectionPrettyPrinter`` named
``MLIRSupport`` and registers it for the current objfile.  Most MLIR value
types are thin wrappers, so the printers here mainly *forward* to the printer
of the one interesting member or base class instead of formatting anything
themselves.
"""

import gdb.printing


class StructPrinter:
    """Prints (optionally) bases of a struct and its fields."""

    def __init__(self, val, print_base=True):
        """Remember the value to print.

        Args:
          val: the gdb.Value to display.
          print_base: when false, base-class sub-objects are left out of
            children().
        """
        self.val = val
        self.print_base = print_base

    def children(self):
        """Yield (name, value) pairs for every base class and data member."""
        for field in self.val.type.values():
            if field.is_base_class:
                if self.print_base:
                    # A base-class sub-object is obtained by casting the
                    # value to the base type.
                    yield (field.name, self.val.cast(field.type))
            else:
                yield (field.name, self.val[field.name])


def get_default_or_struct_printer(val, print_base=True):
    """Returns gdb.default_visualizer(val) or falls back to StructPrinter.

    Args:
      val: the gdb.Value to find a printer for.
      print_base: forwarded to StructPrinter when the fallback is used.
    """
    default_printer = gdb.default_visualizer(val)
    if default_printer:
        return default_printer
    return StructPrinter(val, print_base)


class FieldPrinterFactory:
    """Callable returning the printer of one named field of a value."""

    def __init__(self, name):
        """Store the name of the field to forward to."""
        self.name = name

    def __call__(self, val):
        """Return a printer for val[self.name], looking through one pointer."""
        field = val[self.name]
        if field.type.code == gdb.TYPE_CODE_PTR:
            # Show the pointee rather than the raw address.
            field = field.dereference()
        return get_default_or_struct_printer(field)


def get_first_member_printer(val):
    """Returns printer for the first member of val.

    Base classes are skipped; None is returned when val has no data member.
    """
    for field in val.type.fields():
        if not field.is_base_class:
            return FieldPrinterFactory(field.name)(val)
    return None


class CastPrinterFactory:
    """Callable returning a printer for a value after casting it."""

    def __init__(self, type):
        """Store the target type of the cast (a falsy type disables it)."""
        self.type = type

    def __call__(self, val):
        """Return a printer for val cast to self.type, or None without a type."""
        if not self.type:
            return None
        return get_default_or_struct_printer(val.cast(self.type))


def get_first_base_printer(val):
    """Returns printer for the first base class of val.

    None is returned when val's type has no base class.
    """
    for field in val.type.fields():
        if field.is_base_class:
            return CastPrinterFactory(field.type)(val)
    return None


class IdentifierPrinter:
    """Prints an mlir::Identifier instance."""

    def __init__(self, val):
        """Remember the mlir::Identifier value to print."""
        self.val = val

    def to_string(self):
        """Return the string read through the 'pointer' member."""
        return self.val['pointer'].string()

    def display_hint(self):
        """Tell GDB to format the result of to_string() as a string."""
        return 'string'


def get_operation_name_printer(val):
    """Returns printer for an mlir::OperationName instance.

    The printer of the 'representation' member is looked up first.  If that
    printer exposes a `pointer` attribute, the pointer is followed and a
    printer for the pointee is returned; otherwise the member's own printer
    is returned unchanged.  None is returned when the member has no printer.
    """
    rep_printer = gdb.default_visualizer(val['representation'])
    if not rep_printer:
        return None
    if not hasattr(rep_printer, 'pointer'):
        return rep_printer
    pointer = rep_printer.pointer
    if pointer.type.code == gdb.TYPE_CODE_PTR:
        pointer = pointer.dereference()
    return get_default_or_struct_printer(pointer)


def get_storage_user_base_printer(val):
    """Returns printer for an mlir::detail::StorageUserBase instance.

    The object behind the 'impl' pointer is cast to the type given by the
    template argument at index 2 (used here as the storage type) and a
    printer for that object is returned.  None is returned when 'impl' is
    null.
    """
    if not val['impl']:
        return None
    storage_type = val.type.template_argument(2)
    impl = val['impl'].dereference().cast(storage_type)
    return get_default_or_struct_printer(impl)


pp = gdb.printing.RegexpCollectionPrettyPrinter('MLIRSupport')

# NOTE: the collection tries its sub-printers in registration order and uses
# the first one whose regex matches the type name.  The patterns in the two
# loops below are therefore anchored with '$': without it '^mlir::Location'
# would also claim 'mlir::LocationAttr' (hiding the base-class printer
# registered further down), and '^mlir::Value' / '^mlir::Attribute' would
# claim unrelated types that merely share the prefix.

# Forwarding printers to only relevant field.
pp.add_printer('mlir::AbstractOperation', '^mlir::AbstractOperation$',
               FieldPrinterFactory('name'))
for name in ['Location', 'Attribute', 'NamedAttributeList', 'Value']:
    pp.add_printer('mlir::%s' % name, '^mlir::%s$' % name,
                   get_first_member_printer)

# Forwarding printers to the only relevant base class.
for name in ['DictionaryAttr', 'ElementsAttr', 'EnumAttr', 'EnumCaseAttr',
             'LocationAttr', 'StructAttr']:
    pp.add_printer('mlir::%s' % name, '^mlir::%s$' % name,
                   get_first_base_printer)

pp.add_printer('mlir::Identifier', '^mlir::Identifier$', IdentifierPrinter)
pp.add_printer('mlir::OperationName', '^mlir::OperationName$',
               get_operation_name_printer)
# Template instantiation: the type name continues with the template
# arguments, so this pattern is intentionally left open-ended.
pp.add_printer('mlir::detail::StorageUserBase',
               '^mlir::detail::StorageUserBase<.*>',
               get_storage_user_base_printer)

# The final True asks GDB to replace an already registered printer of the
# same name, so re-sourcing the script does not fail.
gdb.printing.register_pretty_printer(gdb.current_objfile(), pp, True)