|
| 1 | +"""Code generator for Code Completion Model Inference. |
| 2 | +
|
| 3 | +Tool runs on the Decision Forest model defined in {model} directory. |
| 4 | +It generates two files: {output_dir}/{filename}.h and {output_dir}/{filename}.cpp |
| 5 | +The generated files define the Example class named {cpp_class} having all the features as class members. |
| 6 | +The generated runtime provides an `Evaluate` function which can be used to score a code completion candidate. |
| 7 | +""" |
| 8 | + |
| 9 | +import argparse |
| 10 | +import json |
| 11 | +import struct |
| 12 | +from enum import Enum |
| 13 | + |
| 14 | + |
class CppClass:
    """Splits a (possibly namespace-qualified) C++ class name into parts.

    Attributes:
        ns: Names of the enclosing namespaces, outermost first. Empty
            components (e.g. produced by a leading "::") are dropped.
        name: The unqualified class name, i.e. the text after the last "::".
    """

    def __init__(self, cpp_class):
        """Parses `cpp_class`, a name such as "ns1::ns2::Example".

        Raises:
            ValueError: If the component after the last "::" is empty.
        """
        # str.split always yields at least one element, so the last
        # component (the class name) is always present.
        *namespaces, class_name = cpp_class.split("::")
        self.ns = [part for part in namespaces if part]
        self.name = class_name
        if not class_name:
            raise ValueError("Empty class name.")

    def ns_begin(self):
        """Returns one `namespace X {` line per namespace, outermost first."""
        return "\n".join("namespace {} {{".format(ns) for ns in self.ns)

    def ns_end(self):
        """Returns the matching closing-brace lines, innermost first."""
        closers = []
        for ns in self.ns[::-1]:
            closers.append("}} // namespace {}".format(ns))
        return "\n".join(closers)
| 35 | + |
| 36 | + |
def header_guard(filename):
    """Returns the header guard macro name for the generated header.

    The file name is upper-cased and every character other than A-Z, 0-9
    and '_' is replaced by '_', so that names such as "my-model" still
    produce a valid preprocessor identifier. Names that are already valid
    identifiers yield the same guard as before.
    """
    sanitized = "".join(
        ch if ("A" <= ch <= "Z" or "0" <= ch <= "9") else "_"
        for ch in filename.upper())
    return f"GENERATED_DECISION_FOREST_MODEL_{sanitized}_H"
| 40 | + |
| 41 | + |
def boost_node(n, label, next_label):
    """Returns the C++ statement for a leaf/boost node.

    The statement is labelled `label`, adds the leaf's `score` to `Score`
    and then jumps to `next_label` (the root of the next tree).
    """
    leaf_score = n['score']
    return "{}: Score += {}; goto {};".format(label, leaf_score, next_label)
| 46 | + |
| 47 | + |
def if_greater_node(n, label, next_label):
    """Returns the C++ statement for an if_greater node.

    The statement is labelled `label` and jumps to `next_label` when the
    Example's (NUMBER) feature compares `>=` to the threshold; otherwise
    control falls through to the following statement.

    Comparing integers is much faster than comparing floats, so the
    threshold is order-encoded to an integer here (assuming IEEE 754
    floats). The original float is kept in a trailing comment for
    readability of the generated code.

    NOTE(review): the operation is named "if_greater" but the emitted
    comparison is inclusive (`>=`) - confirm this matches the model.
    """
    raw_threshold = n["threshold"]
    feature = n['feature']
    encoded = order_encode(raw_threshold)
    condition = f"E.{feature} >= {encoded} /*{raw_threshold}*/"
    return f"{label}: if ({condition}) goto {next_label};"
| 56 | + |
| 57 | + |
def if_member_node(n, label, next_label):
    """Returns the C++ statement for an if_member node.

    The statement is labelled `label` and jumps to `next_label` when the
    Example's (ENUM) feature, stored as a bit mask, has a bit in common
    with the mask built from the enum values listed in `n["set"]`.
    Control falls through if the condition evaluates to false.
    """
    allowed = n["set"]
    feature = n['feature']
    bits = []
    for member in allowed:
        bits.append(f"BIT({feature}_type::{member})")
    mask = '|'.join(bits)
    return "{}: if (E.{} & ({})) goto {};".format(
        label, feature, mask, next_label)
| 68 | + |
| 69 | + |
def node(n, label, next_label):
    """Returns the C++ statement for node `n`.

    Dispatches on `n['operation']` to the matching snippet generator.
    An unknown operation raises KeyError from the table lookup.
    """
    emitters = {
        'boost': boost_node,
        'if_greater': if_greater_node,
        'if_member': if_member_node,
    }
    emit = emitters[n['operation']]
    return emit(n, label, next_label)
| 77 | + |
| 78 | + |
def tree(t, tree_num: int, node_num: int):
    """Returns the code for evaluating a decision (sub)tree and its size.

    The result is a pair `(lines, size)`: the list of generated C++
    statements and the number of nodes in the (sub)tree rooted at `t`.

    A tree starts with its label `t{tree#}`.
    A node of the tree starts with label `t{tree#}_n{node#}`.

    The tree contains two types of node: conditional nodes and leaf nodes.
    - A conditional node evaluates a condition. If true, it jumps to the
      true child. Code is generated by a pre-order traversal that visits
      the false child first, so the false child is always the immediately
      following label and is reached by falling through.
    - A leaf node adds its value to the score and jumps to the next tree.
    """
    label = f"t{tree_num}_n{node_num}"
    # Only the root node is preceded by the label of the tree itself.
    code = [f"t{tree_num}:"] if node_num == 0 else []

    if t["operation"] == "boost":
        code.append(node(t, label=label, next_label=f"t{tree_num+1}"))
        return code, 1

    # The false subtree occupies the node numbers right after this node.
    else_lines, else_count = tree(
        t['else'], tree_num=tree_num, node_num=node_num + 1)

    # The true subtree is numbered after everything in the false subtree.
    then_start = node_num + else_count + 1
    then_lines, then_count = tree(
        t['then'], tree_num=tree_num, node_num=then_start)

    code.append(
        node(t, label=label, next_label=f"t{tree_num}_n{then_start}"))

    return [*code, *else_lines, *then_lines], 1 + else_count + then_count
| 114 | + |
| 115 | + |
def gen_header_code(features_json: list, cpp_class, filename: str):
    """Returns code for the header declaring the inference runtime.

    Declares the Example class named {cpp_class} inside the relevant
    namespaces. The Example class contains all the features as class
    members and can be used to represent a code completion candidate.
    Also declares `float Evaluate(const {cpp_class}&)`, which can be used
    to score the Example.

    Args:
        features_json: List of feature descriptions; each entry provides a
            "name" and a "kind" ("NUMBER" or "ENUM").
        cpp_class: Object providing `name`, `ns_begin()` and `ns_end()`.
        filename: Base name of the generated files, used for the guard.

    Raises:
        ValueError: If a feature has a kind other than NUMBER or ENUM.
    """
    setters = []
    for f in features_json:
        feature = f["name"]
        if f["kind"] == "NUMBER":
            # Floats are order-encoded to integers for faster comparison.
            setters.append(
                f"void set{feature}(float V) {{ {feature} = OrderEncode(V); }}")
        elif f["kind"] == "ENUM":
            # Enum values are stored as a one-hot bit mask. Shift an
            # unsigned one so that setting bit 31 of the uint32_t member
            # does not overflow a signed int.
            setters.append(
                f"void set{feature}(unsigned V) {{ {feature} = uint32_t{{1}} << V; }}")
        else:
            raise ValueError("Unhandled feature type.", f["kind"])

    # Class members represent all the features of the Example.
    class_members = [f"uint32_t {f['name']} = 0;" for f in features_json]

    nline = "\n "
    guard = header_guard(filename)
    return f"""#ifndef {guard}
#define {guard}
#include <cstdint>

{cpp_class.ns_begin()}
class {cpp_class.name} {{
public:
 {nline.join(setters)}

private:
 {nline.join(class_members)}

 // Produces an integer that sorts in the same order as F.
 // That is: a < b <==> OrderEncode(a) < OrderEncode(b).
 static uint32_t OrderEncode(float F);
 friend float Evaluate(const {cpp_class.name}&);
}};

float Evaluate(const {cpp_class.name}&);
{cpp_class.ns_end()}
#endif // {guard}
"""
| 164 | + |
| 165 | + |
def order_encode(v: float):
    """Returns an integer that sorts in the same order as the float `v`.

    `v` is converted to its 32-bit IEEE 754 representation. Because such
    floats compare like sign-magnitude integers, negative values are
    mapped (order reversed) onto the low half of the unsigned 32-bit
    range and non-negative values onto the high half.
    """
    bits = int.from_bytes(struct.pack('<f', v), 'little')
    sign_bit = 0x80000000
    if (bits & sign_bit) == 0:
        # Non-negative float: top half of integers.
        return bits + sign_bit
    # Negative float: low half of integers, order reversed.
    return 0x100000000 - bits
| 173 | + |
| 174 | + |
def evaluate_func(forest_json: list, cpp_class: CppClass):
    """Generates code for the `float Evaluate(const {Example}&)` function.

    The generated function starts with a zero score, runs the code of
    every tree in `forest_json` in order, and returns the accumulated
    score. The generated function can be used to score an Example.
    """
    signature = f"float Evaluate(const {cpp_class.name}& E) {{\n"

    body = ["float Score = 0;"]
    for index, tree_json in enumerate(forest_json):
        tree_code, _ = tree(tree_json, tree_num=index, node_num=0)
        body.extend(tree_code)
        # Blank separator line between consecutive trees.
        body.append("")

    # Leaves of the last tree jump to this label, one past the last tree.
    body.append(f"t{len(forest_json)}: // No such tree.")
    body.append("return Score;")

    return signature + " " + "\n ".join(body) + "\n}"
| 192 | + |
| 193 | + |
def gen_cpp_code(forest_json: list, features_json: list, filename: str,
                 cpp_class: CppClass):
    """Generates code for the .cpp file.

    The generated file includes the generated header and the headers of
    all ENUM features, defines `{cpp_class}::OrderEncode` and the
    `Evaluate` function for the forest.

    Args:
        forest_json: List of decision trees of the model.
        features_json: List of feature descriptions; ENUM features must
            provide "name", "type" and "header".
        filename: Base name of the generated files.
        cpp_class: Name and namespaces of the generated Example class.
    """
    # Headers
    # <limits> is required by OrderEncode(float F).
    angled_include = [
        f'#include <{h}>'
        for h in ["limits"]
    ]

    # Include the generated header and llvm::bit_cast (used by OrderEncode).
    quoted_headers = {f"{filename}.h", "llvm/ADT/bit.h"}
    # Headers required by ENUM features used by the model.
    quoted_headers |= {f["header"]
                       for f in features_json if f["kind"] == "ENUM"}
    quoted_include = [f'#include "{h}"' for h in sorted(quoted_headers)]

    # using-decl for ENUM features.
    using_decls = "\n".join(f"using {feature['name']}_type = {feature['type']};"
                            for feature in features_json
                            if feature["kind"] == "ENUM")
    nl = "\n"
    # BIT parenthesises its argument and shifts an unsigned one so that the
    # mask for bit 31 does not overflow a signed int. OrderEncode reads the
    # float's bits exactly once, via llvm::bit_cast.
    return f"""{nl.join(angled_include)}

{nl.join(quoted_include)}

#define BIT(X) (uint32_t{{1}} << (X))

{cpp_class.ns_begin()}

{using_decls}

uint32_t {cpp_class.name}::OrderEncode(float F) {{
 static_assert(std::numeric_limits<float>::is_iec559, "");
 constexpr uint32_t TopBit = ~(~uint32_t{{0}} >> 1);

 // Get the bits of the float. Endianness is the same as for integers.
 uint32_t U = llvm::bit_cast<uint32_t>(F);
 // IEEE 754 floats compare like sign-magnitude integers.
 if (U & TopBit) // Negative float.
 return 0 - U; // Map onto the low half of integers, order reversed.
 return U + TopBit; // Positive floats map onto the high half of integers.
}}

{evaluate_func(forest_json, cpp_class)}
{cpp_class.ns_end()}
"""
| 242 | + |
| 243 | + |
def main():
    """Command-line entry point.

    Reads `features.json` and `forest.json` from the --model directory and
    writes `{output_dir}/{filename}.h` and `{output_dir}/{filename}.cpp`.
    All four command-line options are required.
    """
    parser = argparse.ArgumentParser('DecisionForestCodegen')
    # Every option is used unconditionally below, so a missing one is
    # reported by argparse instead of producing paths like "None/None.h".
    parser.add_argument('--filename', required=True, help='output file name.')
    parser.add_argument('--output_dir', required=True,
                        help='output directory.')
    parser.add_argument('--model', required=True,
                        help='path to model directory.')
    parser.add_argument(
        '--cpp_class',
        required=True,
        help='The name of the class (which may be a namespace-qualified) created in generated header.'
    )
    ns = parser.parse_args()

    output_dir = ns.output_dir
    filename = ns.filename
    header_file = f"{output_dir}/{filename}.h"
    cpp_file = f"{output_dir}/{filename}.cpp"
    cpp_class = CppClass(cpp_class=ns.cpp_class)

    model_file = f"{ns.model}/forest.json"
    features_file = f"{ns.model}/features.json"

    with open(features_file) as f:
        features_json = json.load(f)

    with open(model_file) as m:
        forest_json = json.load(m)

    # Generate both outputs before opening any output file, so that a
    # failure during code generation does not leave truncated files behind.
    cpp_code = gen_cpp_code(forest_json=forest_json,
                            features_json=features_json,
                            filename=filename,
                            cpp_class=cpp_class)
    header_code = gen_header_code(
        features_json=features_json, cpp_class=cpp_class, filename=filename)

    with open(cpp_file, 'w') as output_cc:
        output_cc.write(cpp_code)

    with open(header_file, 'w') as output_h:
        output_h.write(header_code)


if __name__ == '__main__':
    main()
0 commit comments