#!/usr/bin/env python3

# NOTE: this script was introduced together with a change to
# llvm/utils/UpdateTestChecks/common.py that selects '//' as the default
# comment prefix for test paths ending in '.mlir' (tested before the existing
# '.mir' -> '#' case, with ';' remaining the fallback).

"""A script to generate FileCheck statements for 'mlir-opt' regression tests.

This script is a utility to update MLIR opt test cases with new
FileCheck patterns. It can either update all of the tests in the file or
a single test function.

Example usage:
$ update_mlir_test_checks.py --opt-binary=../bin/mlir-opt test/foo.mlir

Workflow:
1. Make a compiler patch that requires updating some number of FileCheck lines
   in regression test files.
2. Save the patch and revert it from your local work area.
3. Update the RUN-lines in the affected regression tests to look canonical.
   Example: "// RUN: mlir-opt %s -canonicalize | FileCheck %s"
4. Refresh the FileCheck lines for either the entire file or select functions by
   running this script.
5. Commit the fresh baseline of checks.
6. Apply your patch from step 1 and rebuild your local binaries.
7. Re-run this script on affected regression tests.
8. Check the diffs to ensure the script has done something reasonable.
9. Submit a patch including the regression test diffs for review.

A common pattern is to have the script insert complete checking of every
instruction. Then, edit it down to only check the relevant instructions.
The script is designed to make adding checks to a test case fast, it is *not*
designed to be authoritative about what constitutes a good test!
"""

from __future__ import print_function

import argparse
import glob
import itertools
import os  # Used to advertise this file's name ("autogenerated_note").
import string
import subprocess
import sys
import tempfile
import re


# Regex command to match an SSA identifier.
SSA_RE_STR = '[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
SSA_RE = re.compile(SSA_RE_STR)

# Matches one whole 'func' definition in the tool output, from the 'func'
# keyword down to a closing '}' that sits alone at column 0.
# NOTE(review): the named groups (func / args_and_sig / body) are presumably
# the ones common.build_function_body_dictionary looks up -- confirm against
# UpdateTestChecks/common.py.
MLIROPT_FUNCTION_RE = re.compile(
    r'^\s*func\s+[^@]*@(?P<func>[\w.-]+?)\s*'
    r'(?P<args_and_sig>\((\)|(.*?[\w.-]+?)\))[^{]*)\{\n(?P<body>.*?)^\}$',
    flags=(re.M | re.S))

# Matches the first line of a 'func' definition in the test file itself and
# captures the function name as group 1.
MLIR_FUNCTION_RE = re.compile(r'^\s*func\s+[^@]*@([\w.-]+)\s*\(')


# Class used to generate and manage string substitution blocks for SSA value
# names.
class SSAVariableNamer:
    """Hands out FileCheck variable names ('VAL_<n>') for SSA value names.

    Names are recorded in a stack of scopes (one dict per scope, innermost
    last) so that a mapping disappears when its scope is popped. The counter
    is never reset, so every generated variable name is unique per instance.
    """

    def __init__(self):
        self.scopes = []
        self.name_counter = 0

    # Generate a substitution name for the given ssa value name.
    def generate_name(self, ssa_name):
        """Create, record in the innermost scope, and return a new variable."""
        # An SSA value can show up before any '{' opened a scope; open one on
        # demand instead of failing on scopes[-1].
        if not self.scopes:
            self.push_name_scope()
        variable = 'VAL_' + str(self.name_counter)
        self.name_counter += 1
        self.scopes[-1][ssa_name] = variable
        return variable

    # Push a new variable name scope.
    def push_name_scope(self):
        """Open a new, empty innermost scope."""
        self.scopes.append({})

    # Pop the last variable name scope.
    def pop_name_scope(self):
        """Discard the innermost scope and every mapping recorded in it."""
        self.scopes.pop()


# Process a line of input that has been split at each SSA identifier '%'.
def process_line(line_chunks, variable_namer):
    """Rebuild a line from its '%'-separated chunks using FileCheck variables.

    Each chunk is the text that followed one '%' in the original line. The
    leading SSA identifier of a chunk is replaced by '%[[VAL_n:.*]]' the first
    time it is seen (a definition) and by '%[[VAL_n]]' afterwards (a use); the
    remainder of the chunk is copied unchanged.

    A chunk that does not start with an SSA identifier (a '%' followed by
    whitespace, punctuation or nothing) is emitted verbatim behind its '%'.
    """
    pieces = []
    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # Not an SSA value after all: put the '%' back and move on.
            pieces.append('%' + chunk)
            continue
        ssa_name = m.group(0)

        # Check if an existing variable exists for this name.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            # Known name: emit a use of the existing variable.
            pieces.append('%[[' + variable + ']]')
        else:
            # Unknown name: emit a definition of a new variable.
            variable = variable_namer.generate_name(ssa_name)
            pieces.append('%[[' + variable + ':.*]]')

        # Append the text that followed the identifier.
        pieces.append(chunk[len(ssa_name):])

    return ''.join(pieces)


# Pre-process a line of input to remove any character sequences that will be
# problematic with FileCheck.
def preprocess_line(line):
    """Return `line` with bracket sequences FileCheck would misread escaped."""
    # Replace any double brackets, '[[' with escaped replacements. '[['
    # corresponds to variable names in FileCheck.
    output_line = line.replace('[[', '{{\\[\\[}}')

    # Replace any single bracket that is followed by an SSA identifier: the
    # identifier will be replaced by a '[[...]]' variable, which would
    # recreate the situation above.
    output_line = output_line.replace('[%', '{{\\[}}%')

    return output_line


# Hacked from generate-test-checks.py. Not sure how to do this with
# common.add_ir_checks
def add_mlir_checks(output_lines, args, raw_tool_output, variable_namer):
    """Append FileCheck lines for every non-blank line of `raw_tool_output`.

    output_lines    -- list of strings that the generated checks are appended
                       to. An entry may contain an embedded newline (a
                       '-LABEL' line followed by its '-SAME' line).
    args            -- object whose `check_prefix` attribute is the FileCheck
                       prefix to emit.
    raw_tool_output -- the complete textual output of the tool.
    variable_namer  -- SSAVariableNamer used for the substitutions; a fresh
                       one is created when None is passed.
    """
    if variable_namer is None:
        variable_namer = SSAVariableNamer()

    check_tag = '// ' + args.check_prefix
    label_tag = '-LABEL: '
    # Same width as the label tag so that '-SAME' continuations line up with
    # the text of the '-LABEL' line above them.
    same_tag = '-SAME: '.ljust(len(label_tag))
    # Nesting level 1 means exactly one two-space indentation step.
    toplevel_indent = ' ' * 2

    for input_line in raw_tool_output.split('\n'):
        lstripped_input_line = input_line.lstrip()
        if not lstripped_input_line:
            continue

        # Lines with blocks begin with a ^. These lines have a trailing
        # comment that needs to be stripped.
        is_block = lstripped_input_line[0] == '^'
        if is_block:
            input_line = input_line.rsplit('//', 1)[0].rstrip()

        # Top-level operations are heuristically the operations at nesting
        # level 1: indented by exactly one step and not a closing brace.
        is_toplevel_op = (
            not is_block
            and input_line.startswith(toplevel_indent)
            and input_line[len(toplevel_indent)] not in (' ', '}'))

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line[0] == '}':
            variable_namer.pop_name_scope()

        # If the line ends with a '{', push a new name scope.
        if input_line.endswith('{'):
            variable_namer.push_name_scope()

        # Preprocess the input to remove any sequences that may be
        # problematic with FileCheck.
        input_line = preprocess_line(input_line)

        # Split the line at each SSA value name.
        ssa_split = input_line.split('%')

        # If this is a top-level operation use 'CHECK-LABEL', otherwise
        # 'CHECK:'. A label needs some literal text before the first SSA
        # value, so lines starting with '%' also get a plain check.
        if not is_toplevel_op or not ssa_split[0]:
            output_line = check_tag + ': '
            # Pad to align with the 'LABEL' statements.
            output_line += ' ' * len('-LABEL')

            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]

            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)
        else:
            # Append an empty line to separate the logical blocks.
            output_lines.append('')

            # The label is the text that precedes the first SSA name.
            output_line = check_tag + label_tag + ssa_split[0]

            # Process the rest of the input line on a separate check line.
            if len(ssa_split) > 1:
                output_line += '\n' + check_tag + same_tag

                # Pad to align with the original position in the line.
                output_line += ' ' * len(ssa_split[0])

                # Process the rest of the line.
                output_line += process_line(ssa_split[1:], variable_namer)

        # Append the finished check line(s).
        output_lines.append(output_line)


def main():
    """Regenerate the FileCheck lines of every test named on the command line.

    For each test file the RUN lines are parsed, the tool is run once per
    usable RUN line, and the file is rewritten in place: check lines that
    belong to a known prefix are dropped and new ones are inserted right after
    the first line of each (selected) function.
    """
    from argparse import RawTextHelpFormatter
    # Imported here rather than at module level so that the helpers above can
    # be imported (e.g. by unit tests) without the LLVM utils directory on
    # sys.path.
    from UpdateTestChecks import common

    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=RawTextHelpFormatter)
    parser.add_argument('--opt-binary', default='mlir-opt',
                        help='The opt binary used to generate the test case')
    parser.add_argument(
        '--function', help='The function in the test file to update')
    parser.add_argument('-p', '--preserve-names', action='store_true',
                        help='Do not scrub IR names')
    parser.add_argument('--function-signature', action='store_true',
                        help='Keep function signature information around for the check line')
    parser.add_argument('--scrub-attributes', action='store_true',
                        help='Remove attribute annotations (#0) from the end of check line')
    parser.add_argument(
        '--check-prefix', default='CHECK', help='Prefix to use from check file.')
    parser.add_argument('tests', nargs='+')
    initial_args = common.parse_commandline_args(parser)

    script_name = os.path.basename(__file__)
    opt_basename = os.path.basename(initial_args.opt_binary)
    if not re.match(r'^(.)*opt(-\d+)?$', opt_basename):
        common.error('Unexpected opt name: ' + opt_basename)
        sys.exit(1)

    for ti in common.itertests(initial_args.tests, parser,
                               script_name='utils/' + script_name):
        # If requested we scrub trailing attribute annotations, e.g., '#0',
        # together with whitespaces.
        if ti.args.scrub_attributes:
            common.SCRUB_TRAILING_WHITESPACE_TEST_RE = common.SCRUB_TRAILING_WHITESPACE_AND_ATTRIBUTES_RE
        else:
            common.SCRUB_TRAILING_WHITESPACE_TEST_RE = common.SCRUB_TRAILING_WHITESPACE_RE

        # One (check prefixes, tool arguments) pair per usable RUN line.
        prefix_list = []
        for run_line in ti.run_lines:
            if '|' not in run_line:
                common.warn('Skipping unparseable RUN line: ' + run_line)
                continue

            tool_cmd, filecheck_cmd = [
                cmd.strip() for cmd in run_line.split('|', 1)]
            common.verify_filecheck_prefixes(filecheck_cmd)
            if not tool_cmd.startswith(opt_basename + ' '):
                common.warn('Skipping non-%s RUN line: %s'
                            % (opt_basename, run_line))
                continue

            if not filecheck_cmd.startswith('FileCheck '):
                common.warn('Skipping non-FileChecked RUN line: ' + run_line)
                continue

            # Drop the tool name and the input-file placeholder; the real
            # path is supplied when the tool is invoked below.
            tool_cmd_args = tool_cmd[len(opt_basename):].strip()
            tool_cmd_args = tool_cmd_args.replace('< %s', '').replace('%s', '').strip()

            check_prefixes = [item for m in
                              common.CHECK_PREFIX_RE.finditer(filecheck_cmd)
                              for item in m.group(1).split(',')]
            if not check_prefixes:
                check_prefixes = ['CHECK']

            # FIXME: We should use multiple check prefixes to common check
            # lines. For now, we just ignore all but the last.
            prefix_list.append((check_prefixes, tool_cmd_args))

        if not prefix_list:
            # Without tool output there is nothing to generate checks from
            # (and `raw_tool_output` below would never be bound).
            common.warn('No usable RUN lines found, skipping: ' + ti.path)
            continue

        func_dict = {}
        for prefixes, _ in prefix_list:
            for prefix in prefixes:
                func_dict[prefix] = {}
        for prefixes, opt_args in prefix_list:
            common.debug('Extracted opt cmd: ' + opt_basename + ' ' + opt_args)
            common.debug('Extracted FileCheck prefixes: ' + str(prefixes))

            raw_tool_output = common.invoke_tool(ti.args.opt_binary, opt_args, ti.path)
            common.build_function_body_dictionary(
                MLIROPT_FUNCTION_RE, common.scrub_body, [],
                raw_tool_output, prefixes, func_dict, ti.args.verbose,
                ti.args.function_signature)
        # NOTE: from here on `raw_tool_output` is the output of the *last*
        # usable RUN line; the checks are generated from it (with
        # args.check_prefix as the prefix) rather than from func_dict.

        is_in_function = False
        is_in_function_start = False
        prefix_set = {prefix for prefixes, _ in prefix_list for prefix in prefixes}
        common.debug('Rewriting FileCheck prefixes:', str(prefix_set))
        output_lines = []

        # Holds the data used for naming SSA values.
        variable_namer = SSAVariableNamer()
        for input_line_info in ti.iterlines(output_lines):
            input_line = input_line_info.line
            args = input_line_info.args
            if is_in_function_start:
                if input_line == '':
                    continue
                if input_line.lstrip().startswith('//'):
                    m = common.CHECK_RE.match(input_line)
                    if not m or m.group(1) not in prefix_set:
                        # An ordinary comment (or a check for a prefix we do
                        # not own): keep it above the generated checks.
                        output_lines.append(input_line)
                        continue

                # Print out the various check lines here.
                add_mlir_checks(output_lines, args, raw_tool_output, variable_namer)

                is_in_function_start = False

            if is_in_function:
                if common.should_add_line_to_output(input_line, prefix_set):
                    # This input line of the function body goes as-is into
                    # the output.
                    output_lines.append(input_line)
                else:
                    continue
                if input_line.strip() == '}':
                    is_in_function = False
                continue

            # If it's outside a function, it just gets copied to the output.
            output_lines.append(input_line)

            m = MLIR_FUNCTION_RE.match(input_line)
            if not m:
                continue
            func_name = m.group(1)
            if args.function is not None and func_name != args.function:
                # When filtering on a specific function, skip all others.
                continue
            is_in_function = is_in_function_start = True

        common.debug('Writing %d lines to %s...' % (len(output_lines), ti.path))

        with open(ti.path, 'wb') as f:
            f.writelines(['{}\n'.format(l).encode('utf-8') for l in output_lines])


if __name__ == '__main__':
    main()