diff --git a/clang-tools-extra/clangd/quality/CompletionModelCodegen.py b/clang-tools-extra/clangd/quality/CompletionModelCodegen.py --- a/clang-tools-extra/clangd/quality/CompletionModelCodegen.py +++ b/clang-tools-extra/clangd/quality/CompletionModelCodegen.py @@ -1,7 +1,7 @@ """Code generator for Code Completion Model Inference. Tool runs on the Decision Forest model defined in {model} directory. -It generates two files: {output_dir}/{filename}.h and {output_dir}/{filename}.cpp +It generates two files: {output_dir}/{filename}.h and {output_dir}/{filename}.cpp The generated files defines the Example class named {cpp_class} having all the features as class members. The generated runtime provides an `Evaluate` function which can be used to score a code completion candidate. """ @@ -39,34 +39,32 @@ def boost_node(n, label, next_label): - """Returns code snippet for a leaf/boost node. - Adds value of leaf to the score and jumps to the root of the next tree.""" - return "%s: Score += %s; goto %s;" % ( - label, n['score'], next_label) + """Returns code snippet for a leaf/boost node.""" + return "%s: return %s;" % (label, n['score']) def if_greater_node(n, label, next_label): """Returns code snippet for a if_greater node. - Jumps to true_label if the Example feature (NUMBER) is greater than the threshold. - Comparing integers is much faster than comparing floats. Assuming floating points + Jumps to true_label if the Example feature (NUMBER) is greater than the threshold. + Comparing integers is much faster than comparing floats. Assuming floating points are represented as IEEE 754, it order-encodes the floats to integers before comparing them. 
Control falls through if condition is evaluated to false.""" threshold = n["threshold"] - return "%s: if (E.%s >= %s /*%s*/) goto %s;" % ( - label, n['feature'], order_encode(threshold), threshold, next_label) + return "%s: if (%s >= %s /*%s*/) goto %s;" % ( + label, n['feature'], order_encode(threshold), threshold, next_label) def if_member_node(n, label, next_label): """Returns code snippet for a if_member node. - Jumps to true_label if the Example feature (ENUM) is present in the set of enum values + Jumps to true_label if the Example feature (ENUM) is present in the set of enum values described in the node. Control falls through if condition is evaluated to false.""" members = '|'.join([ "BIT(%s_type::%s)" % (n['feature'], member) for member in n["set"] ]) - return "%s: if (E.%s & (%s)) goto %s;" % ( - label, n['feature'], members, next_label) + return "%s: if (%s & (%s)) goto %s;" % ( + label, n['feature'], members, next_label) def node(n, label, next_label): @@ -94,8 +92,6 @@ """ label = "t%d_n%d" % (tree_num, node_num) code = [] - if node_num == 0: - code.append("t%d:" % tree_num) if t["operation"] == "boost": code.append(node(t, label=label, next_label="t%d" % (tree_num + 1))) @@ -115,11 +111,11 @@ return code+false_code+true_code, 1+false_size+true_size -def gen_header_code(features_json, cpp_class, filename): +def gen_header_code(features_json, cpp_class, filename, num_trees): """Returns code for header declaring the inference runtime. Declares the Example class named {cpp_class} inside relevant namespaces. - The Example class contains all the features as class members. This + The Example class contains all the features as class members. This class can be used to represent a code completion candidate. Provides `float Evaluate()` function which can be used to score the Example. 
""" @@ -145,7 +141,6 @@ return """#ifndef %s #define %s #include <cstdint> -#include "llvm/Support/Compiler.h" %s class %s { @@ -158,18 +153,21 @@ // Produces an integer that sorts in the same order as F. // That is: a < b <==> orderEncode(a) < orderEncode(b). static uint32_t OrderEncode(float F); + + // Evaluation functions for each tree. +%s + friend float Evaluate(const %s&); }; -// The function may have large number of lines of code. MSAN -// build times out in such case. -LLVM_NO_SANITIZE("memory") float Evaluate(const %s&); %s #endif // %s """ % (guard, guard, cpp_class.ns_begin(), cpp_class.name, nline.join(setters), - nline.join(class_members), cpp_class.name, cpp_class.name, - cpp_class.ns_end(), guard) + nline.join(class_members), + nline.join([" float EvaluateTree%d() const;" % tree_num + for tree_num in range(num_trees)]), + cpp_class.name, cpp_class.name, cpp_class.ns_end(), guard) def order_encode(v): @@ -181,22 +179,32 @@ return TopBit + i # top half of integers -def evaluate_func(forest_json, cpp_class): - """Generates code for `float Evaluate(const {Example}&)` function. - The generated function can be used to score an Example.""" - code = "float Evaluate(const %s& E) {\n" % cpp_class.name - lines = [] - lines.append("float Score = 0;") +def evaluate_funcs(forest_json, cpp_class): + """Generates evaluation functions for each tree and combines them in + `float Evaluate(const {Example}&)` function. This function can be + used to score an Example.""" + + code = "" + + # Generate evaluation function of each tree. tree_num = 0 for tree_json in forest_json: - lines.extend(tree(tree_json, tree_num=tree_num, node_num=0)[0]) - lines.append("") + code += "float %s::EvaluateTree%d() const {\n" % ( + cpp_class.name, tree_num) + code += " " + \ + " \n".join( + tree(tree_json, tree_num=tree_num, node_num=0)[0]) + "\n" + code += "}\n\n" tree_num += 1 - lines.append("t%s: // No such tree." 
% len(forest_json)) - lines.append("return Score;") - code += " " + "\n ".join(lines) - code += "\n}" + # Combine the scores of all trees in the final function. + code += "float Evaluate(const %s& E) {\n" % cpp_class.name + code += " float Score = 0;\n" + for tree_num in range(len(forest_json)): + code += " Score += E.EvaluateTree%d();\n" % tree_num + code += " return Score;\n" + code += "}\n" + return code @@ -248,7 +256,7 @@ %s %s """ % (nl.join(angled_include), nl.join(quoted_include), cpp_class.ns_begin(), - using_decls, cpp_class.name, evaluate_func(forest_json, cpp_class), + using_decls, cpp_class.name, evaluate_funcs(forest_json, cpp_class), cpp_class.ns_end()) @@ -287,7 +295,10 @@ with open(header_file, 'w+t') as output_h: output_h.write(gen_header_code( - features_json=features_json, cpp_class=cpp_class, filename=filename)) + features_json=features_json, + cpp_class=cpp_class, + filename=filename, + num_trees=len(forest_json))) if __name__ == '__main__':