diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py --- a/mlir/utils/generate-test-checks.py +++ b/mlir/utils/generate-test-checks.py @@ -45,20 +45,60 @@ SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*" SSA_RE = re.compile(SSA_RE_STR) +# Regex matching the left-hand side of an SSA assignment (e.g. '%0, %1 =') +SSA_RESULTS_STR = r'\s*(%' + SSA_RE_STR + r')(\s*,\s*(%' + SSA_RE_STR + r'))*\s*=' +SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR) + +# Regex matching an attribute reference ('#' followed by the attribute name) +ATTR_RE_STR = r'(#[a-zA-Z._-][a-zA-Z0-9._-]*)' +ATTR_RE = re.compile(ATTR_RE_STR) + +# Regex matching the left-hand side of an attribute definition (e.g. '#map =') +ATTR_DEF_RE_STR = r'\s*' + ATTR_RE_STR + r'\s*=' +ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR) + # Class used to generate and manage string substitution blocks for SSA value # names. -class SSAVariableNamer: - def __init__(self): +class VariableNamer: + def __init__(self, variable_names): self.scopes = [] self.name_counter = 0 + # Number of variable names to still generate in parent scope + self.generate_in_parent_scope_left = 0 + + # Parse the comma-separated variable names, upper-casing each one + self.variable_names = [name.upper() for name in variable_names.split(',')] + self.used_variable_names = set() + + # Generate the following 'n' variable names in the parent scope. + def generate_in_parent_scope(self, n): + self.generate_in_parent_scope_left = n + # Generate a substitution name for the given ssa value name. 
- def generate_name(self, ssa_name): - variable = "VAL_" + str(self.name_counter) - self.name_counter += 1 - self.scopes[-1][ssa_name] = variable - return variable + def generate_name(self, source_variable_name): + + # Take the next user-provided name, or default to VAL_<n> if it is missing or empty + variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else '' + if variable_name == '': + variable_name = "VAL_" + str(self.name_counter) + self.name_counter += 1 + + # Save in the innermost scope, or in its parent while parent-scope requests are pending + scope = len(self.scopes) - 1 + if self.generate_in_parent_scope_left > 0: + self.generate_in_parent_scope_left -= 1 + scope = len(self.scopes) - 2 + assert(scope >= 0) + + # Save variable, rejecting names that have already been used + if variable_name in self.used_variable_names: + raise RuntimeError(variable_name + ': duplicate variable name') + self.scopes[scope][source_variable_name] = variable_name + self.used_variable_names.add(variable_name) + + return variable_name # Push a new variable name scope. def push_name_scope(self): @@ -76,6 +116,46 @@ def clear_counter(self): self.name_counter = 0 +class AttributeNamer: + + def __init__(self, attribute_names): + self.name_counter = 0 + self.attribute_names = [name.upper() for name in attribute_names.split(',')] + self.map = {} + self.used_attribute_names = set() + + # Generate a substitution name for the given attribute name. + def generate_name(self, source_attribute_name): + + # Take the next user-provided name, or default to ATTR_<n> if it is missing or empty + attribute_name = self.attribute_names.pop(0) if len(self.attribute_names) > 0 else '' + if attribute_name == '': + attribute_name = "ATTR_" + str(self.name_counter) + self.name_counter += 1 + + # Prepend '$', which marks a global variable in FileCheck + attribute_name = '$' + attribute_name + + # Save attribute, rejecting names that have already been used + if attribute_name in self.used_attribute_names: + raise RuntimeError(attribute_name + ': duplicate attribute name') + self.map[source_attribute_name] = attribute_name + self.used_attribute_names.add(attribute_name) + return attribute_name + + # Get the saved substitution name for the given attribute name. 
If no name + # has been generated for the given attribute yet, the placeholder '?' is + # returned instead. + def get_name(self, source_attribute_name): + return self.map[source_attribute_name] if source_attribute_name in self.map else '?' + +# Return the number of SSA results in a line of type +# %0, %1, ... = ... +# The function returns 0 if there are no results. +def get_num_ssa_results(input_line): + m = SSA_RESULTS_RE.match(input_line) + return m.group().count('%') if m else 0 + # Process a line of input that has been split at each SSA identifier '%'. def process_line(line_chunks, variable_namer): @@ -84,7 +164,7 @@ # Process the rest that contained an SSA value name. for chunk in line_chunks: m = SSA_RE.match(chunk) - ssa_name = m.group(0) + ssa_name = m.group(0) if m is not None else '' # Check if an existing variable exists for this name. variable = None @@ -126,6 +206,25 @@ source_segments[-1].append(line + "\n") return source_segments +def process_attribute_definition(line, attribute_namer, output): + m = ATTR_DEF_RE.match(line) + if m: + attribute_name = attribute_namer.generate_name(m.group(1)) + line = '// CHECK: #[[' + attribute_name + ':.+]] =' + line[len(m.group(0)):] + '\n' + output.write(line) + +def process_attribute_references(line, attribute_namer): + + output_line = '' + components = ATTR_RE.split(line) + for component in components: + m = ATTR_RE.match(component) + if m: + output_line += '#[[' + attribute_namer.get_name(m.group(1)) + ']]' + output_line += component[len(m.group()):] + else: + output_line += component + return output_line # Pre-process a line of input to remove any character sequences that will be # problematic with FileCheck. 
@@ -171,6 +270,20 @@ 'it omits "module {"', ) parser.add_argument("-i", "--inplace", action="store_true", default=False) + parser.add_argument( + "--variable_names", + type=str, + default='', + help="Names to be used in FileCheck regular expression to represent SSA " + "variables in the order they are encountered. Separate names with commas, " + "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')") + parser.add_argument( + "--attribute_names", + type=str, + default='', + help="Names to be used in FileCheck regular expression to represent " + "attributes in the order they are defined. Separate names with commas, " + "and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')") args = parser.parse_args() @@ -197,15 +310,22 @@ output = args.output output_segments = [[]] - # A map containing data used for naming SSA value names. - variable_namer = SSAVariableNamer() + + # Namers for SSA values and attributes, seeded with any user-provided names + variable_namer = VariableNamer(args.variable_names) + attribute_namer = AttributeNamer(args.attribute_names) + + # Process lines for input_line in input_lines: if not input_line: continue - lstripped_input_line = input_line.lstrip() + + # Check if this is an attribute definition and process it + process_attribute_definition(input_line, attribute_namer, output) # Lines with blocks begin with a ^. These lines have a trailing comment # that needs to be stripped. + lstripped_input_line = input_line.lstrip() is_block = lstripped_input_line[0] == "^" if is_block: input_line = input_line.rsplit("//", 1)[0].rstrip() @@ -222,6 +342,10 @@ variable_namer.push_name_scope() if cur_level == args.starts_from_scope: output_segments.append([]) + + # Result SSA values must still be pushed to parent scope + num_ssa_results = get_num_ssa_results(input_line) + variable_namer.generate_in_parent_scope(num_ssa_results) # Omit lines at the near top level e.g. "module {". if cur_level < args.starts_from_scope: @@ -234,6 +358,9 @@ # FileCheck. 
input_line = preprocess_line(input_line) + # Replace attribute uses in this line with references to their FileCheck names + input_line = process_attribute_references(input_line, attribute_namer) + # Split the line at the each SSA value name. ssa_split = input_line.split("%")