diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py --- a/mlir/utils/generate-test-checks.py +++ b/mlir/utils/generate-test-checks.py @@ -56,6 +56,12 @@ def pop_name_scope(self): self.scopes.pop() + def num_scopes(self): + return len(self.scopes) + + def clear_counter(self): + self.name_counter = 0 + # Process a line of input that has been split at each SSA identifier '%'. def process_line(line_chunks, variable_namer): @@ -87,6 +93,22 @@ return output_line + '\n' +def process_source_lines(source_lines, note, args): + source_split_re = re.compile(args.source_delim_regex) + + source_segments = [[]] + for line in source_lines: + if line == note: + continue + if line.find(args.check_prefix) != -1: + continue + if source_split_re.search(line): + source_segments.append([]) + + source_segments[-1].append(line + '\n') + return source_segments + + # Pre-process a line of input to remove any character sequences that will be # problematic with FileCheck. def preprocess_line(line): @@ -112,25 +134,51 @@ '--output', nargs='?', type=argparse.FileType('w'), - default=sys.stdout) + default=None) parser.add_argument( 'input', nargs='?', type=argparse.FileType('r'), default=sys.stdin) + parser.add_argument( + '--source', type=str, + help='Print each CHECK chunk before each delimeter line in the source' + 'file, respectively. The delimeter lines are identified by ' + '--source_delim_regex.') + parser.add_argument('--source_delim_regex', type=str, default='func @') + parser.add_argument( + '--starts_from_scope', type=int, default=1, + help='Omit the top specified level of content. For example, by default ' + 'it omits "module {"') + parser.add_argument('-i', '--inplace', action='store_true', default=False) + args = parser.parse_args() # Open the given input file. input_lines = [l.rstrip() for l in args.input] args.input.close() - output_lines = [] - # Generate a note used for the generated check file. script_name = os.path.basename(__file__) autogenerated_note = (ADVERT + 'utils/' + script_name) - output_lines.append(autogenerated_note + '\n') + source_segments = None + if args.source: + source_segments = process_source_lines( + [l.rstrip() for l in open(args.source, 'r')], + autogenerated_note, + args + ) + + if args.inplace: + assert args.output is None + output = open(args.source, 'w') + elif args.output is None: + output = sys.stdout + else: + output = args.output + + output_segments = [[]] # A map containing data used for naming SSA value names. variable_namer = SSAVariableNamer() for input_line in input_lines: @@ -144,17 +192,25 @@ if is_block: input_line = input_line.rsplit('//', 1)[0].rstrip() - # Top-level operations are heuristically the operations at nesting level 1. - is_toplevel_op = (not is_block and input_line.startswith(' ') and - input_line[2] != ' ' and input_line[2] != '}') + cur_level = variable_namer.num_scopes() # If the line starts with a '}', pop the last name scope. if lstripped_input_line[0] == '}': variable_namer.pop_name_scope() + cur_level = variable_namer.num_scopes() # If the line ends with a '{', push a new name scope. if input_line[-1] == '{': variable_namer.push_name_scope() + if cur_level == args.starts_from_scope: + output_segments.append([]) + + # Omit lines at the near top level e.g. "module {". + if cur_level < args.starts_from_scope: + continue + + if len(output_segments[-1]) == 0: + variable_namer.clear_counter() # Preprocess the input to remove any sequences that may be problematic with # FileCheck. @@ -164,7 +220,7 @@ ssa_split = input_line.split('%') # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'. - if not is_toplevel_op or not ssa_split[0]: + if len(output_segments[-1]) != 0 or not ssa_split[0]: output_line = '// ' + args.check_prefix + ': ' # Pad to align with the 'LABEL' statements. output_line += (' ' * len('-LABEL')) @@ -176,32 +232,41 @@ output_line += process_line(ssa_split[1:], variable_namer) else: - # Append a newline to the output to separate the logical blocks. - output_lines.append('\n') - output_line = '// ' + args.check_prefix + '-LABEL: ' - # Output the first line chunk that does not contain an SSA name for the # label. - output_line += ssa_split[0] + '\n' + output_line = '// ' + args.check_prefix + '-LABEL: ' + ssa_split[0] + '\n' - # Process the rest of the input line on a separate check line. - if len(ssa_split) > 1: + # Process the rest of the input line on separate check lines. + for argument in ssa_split[1:]: output_line += '// ' + args.check_prefix + '-SAME: ' # Pad to align with the original position in the line. output_line += ' ' * len(ssa_split[0]) # Process the rest of the line. - output_line += process_line(ssa_split[1:], variable_namer) + output_line += process_line([argument], variable_namer) # Append the output line. - output_lines.append(output_line) + output_segments[-1].append(output_line) + + output.write(autogenerated_note + '\n') # Write the output. - for output_line in output_lines: - args.output.write(output_line) - args.output.write('\n') - args.output.close() + if source_segments: + assert len(output_segments) == len(source_segments) + for check_segment, source_segment in zip(output_segments, source_segments): + for line in check_segment: + output.write(line) + for line in source_segment: + output.write(line) + output.write('\n') + else: + for segment in output_segments: + output.write('\n') + for output_line in segment: + output.write(output_line) + output.write('\n') + output.close() if __name__ == '__main__':