|
| 1 | +# Lint as: python3 |
| 2 | +from absl import app |
| 3 | +from absl import flags |
| 4 | +import re |
| 5 | + |
| 6 | +import yaml |
| 7 | + |
| 8 | +FLAGS = flags.FLAGS |
| 9 | + |
| 10 | +flags.DEFINE_string("def_file", None, "path to list of ops") |
| 11 | +flags.DEFINE_string("swift_out", None, "path for the generated swift file") |
| 12 | +flags.DEFINE_string("cc_output", None, "path for the generated cc file") |
| 13 | + |
| 14 | +HEADER = """// Autogenerated by codegen.py. Do not modify. |
| 15 | +""" |
| 16 | + |
| 17 | +def node_type_define(op): |
| 18 | + tensor_args = [] |
| 19 | + attr_args = [] |
| 20 | + for arg in op["args"]: |
| 21 | + if arg[1] == "Tensor": tensor_args.append(arg) |
| 22 | + else: attr_args.append(arg) |
| 23 | + def format_pretty_print(arg): |
| 24 | + return f" << \", {arg[0]}=\" << {arg[0]}_" |
| 25 | + def format_ctor_arg(arg): |
| 26 | + name, stype = arg |
| 27 | + if stype == "Tensor": return f"const Value& {name}" |
| 28 | + if stype == "Int64": return f"xla::int64 {name}" |
| 29 | + raise f"Problem: no such type: {stype}" |
| 30 | + lower_arg_i = 0 |
| 31 | + def format_lower_arg(arg): |
| 32 | + nonlocal lower_arg_i |
| 33 | + name, stype = arg |
| 34 | + if stype == "Tensor": |
| 35 | + i = lower_arg_i |
| 36 | + lower_arg_i += 1 |
| 37 | + return "loctx->GetOutputOp(operand(" + str(i) + "))" |
| 38 | + if stype == "Int64": return f"{name}_" |
| 39 | + raise f"Problem: no such type: {stype}" |
| 40 | + clone_arg_i = 0 |
| 41 | + def format_clone_arg(arg): |
| 42 | + nonlocal clone_arg_i |
| 43 | + name, stype = arg |
| 44 | + if stype == "Tensor": |
| 45 | + i = clone_arg_i |
| 46 | + clone_arg_i += 1 |
| 47 | + return "operands.at(" + str(i) + ")" |
| 48 | + if stype == "Int64": return f"{name}_" |
| 49 | + raise f"Problem: no such type: {stype}" |
| 50 | + def format_attr_define(arg): |
| 51 | + name, stype = arg |
| 52 | + if stype == "Int64": return f" xla::int64 {name}_;\n" |
| 53 | + raise f"Problem: no such type: {stype}" |
| 54 | + def format_attr_init(arg): |
| 55 | + return f",\n {arg[0]}_({arg[0]})" |
| 56 | + shape_fn = f"""{{}}\n#error no shape function for {op["op_node_name"]}\n""" |
| 57 | + def resolve_shape_fn(shape_fn): |
| 58 | + for arg in tensor_args: |
| 59 | + if arg[0] == shape_fn: return f"{arg[0]}.shape()" |
| 60 | + return f"""{shape_fn}({", ".join(arg[0] for arg in op["args"])})""" |
| 61 | + if op["shape_fn"]: |
| 62 | + shape_fn = resolve_shape_fn(op["shape_fn"]) |
| 63 | + num_outputs = 1 |
| 64 | + return f""" |
| 65 | +class {op["op_node_name"]} : public Node {{ |
| 66 | + public: |
| 67 | + {op["op_node_name"]}({", ".join(format_ctor_arg(arg) for arg in op["args"])}) |
| 68 | + : Node(ir::OpKind({op["x10_enum"]}), |
| 69 | + {{{", ".join(arg[0] for arg in tensor_args)}}}, {shape_fn}, |
| 70 | + /*num_outputs=*/{str(num_outputs)}, xla::util::MHash({", ".join(arg[0] for arg in attr_args)})){ |
| 71 | +"".join(format_attr_init(arg) for arg in attr_args) |
| 72 | +} {{}} |
| 73 | +
|
| 74 | + NodePtr Clone(OpList operands) const override {{ |
| 75 | + return MakeNode<{op["op_node_name"]}>( |
| 76 | + {", ".join(format_clone_arg(arg) for arg in op["args"])}); |
| 77 | + }} |
| 78 | +
|
| 79 | + XlaOpVector Lower(LoweringContext* loctx) const override {{ |
| 80 | + xla::XlaOp result = {op["lower_fn"]}( |
| 81 | + {", ".join(format_lower_arg(arg) for arg in op["args"])}); |
| 82 | + return ReturnOp(result, loctx); |
| 83 | + }} |
| 84 | +
|
| 85 | + std::string ToString() const override {{ |
| 86 | + std::stringstream ss; |
| 87 | + ss << Node::ToString(){"".join(format_pretty_print(arg) for arg in attr_args)}; |
| 88 | + return ss.str(); |
| 89 | + }} |
| 90 | +
|
| 91 | + private: |
| 92 | +{"".join(format_attr_define(arg) for arg in attr_args)}}}; |
| 93 | +""" |
| 94 | + |
| 95 | +def c_function_define(op): |
| 96 | + def format_arg_def(arg): |
| 97 | + name, stype = arg |
| 98 | + if stype == "Tensor": return "OpaqueXLATensor* " + name |
| 99 | + if stype == "Int64": return "int64_t " + name |
| 100 | + raise "problem unknown type: " + stype |
| 101 | + def format_arg_ref(arg): |
| 102 | + name, stype = arg |
| 103 | + if stype == "Tensor": return name + "_ir_value" |
| 104 | + for extra in op["extras"]: |
| 105 | + if extra[0] == "canonicalize" and extra[1] == name: |
| 106 | + return f"swift_xla::XlaHelpers::GetCanonicalDimensionIndex({name}, {extra[2]}_ir_value.shape().rank())" |
| 107 | + return name |
| 108 | + def unpack_arg(arg): |
| 109 | + name, stype = arg |
| 110 | + if stype == "Tensor": return f" auto {name}_ir_value = {name}->GetIrValue();\n" |
| 111 | + return "" |
| 112 | + args = op["args"] |
| 113 | + first_tensor = args[0][0] |
| 114 | + return f""" |
| 115 | +OpaqueXLATensor* XLATensor_{op["c_name"]}({", ".join(format_arg_def(arg) for arg in op["args"])}) {{ |
| 116 | +{"".join(unpack_arg(arg) for arg in op["args"])} return new swift_xla::XLATensor({first_tensor}->CreateFrom( |
| 117 | + swift_xla::ir::MakeNode<swift_xla::ir::ops::{op["op_node_name"]}>({", ".join(format_arg_ref(arg) for arg in op["args"])}))); |
| 118 | +}} |
| 119 | +""" |
| 120 | + |
| 121 | +def snake_to_camel(name): |
| 122 | + return "".join(map(lambda x: x.capitalize(),name.split("_"))) |
| 123 | + |
| 124 | +def canonicalize_op(op): |
| 125 | + tokens = re.findall("(\w+|[\(\),:]|->)", op["def"]) |
| 126 | + op["c_name"] = tokens[0] |
| 127 | + def expect(cond): |
| 128 | + if not cond: raise ValueError(f"""invalid format: {repr(op["def"])}""") |
| 129 | + expect(tokens[1] == '(') |
| 130 | + def isWord(idx): |
| 131 | + return re.match("\w+", tokens[idx]) != None |
| 132 | + i = 2 |
| 133 | + args = [] |
| 134 | + if tokens[i] != ')': |
| 135 | + while True: |
| 136 | + expect(tokens[i + 1] == ':') |
| 137 | + expect(isWord(i) and isWord(i + 2)) |
| 138 | + args.append((tokens[i], tokens[i + 2])) |
| 139 | + i += 3 |
| 140 | + if tokens[i] == ')': break |
| 141 | + expect(tokens[i] == ',') |
| 142 | + i += 1 |
| 143 | + i += 1 |
| 144 | + |
| 145 | + op["args"] = args |
| 146 | + if "op_node_name" not in op: op["op_node_name"] = snake_to_camel(op["c_name"]) |
| 147 | + op["extras"] = [a.split() for a in op["extras"]] |
| 148 | + del op["def"] |
| 149 | + |
| 150 | +def main(argv): |
| 151 | + if len(argv) > 1: |
| 152 | + raise app.UsageError("Too many command-line arguments.") |
| 153 | + op_list = yaml.full_load(open(FLAGS.def_file).read()) |
| 154 | + for op in op_list: canonicalize_op(op) |
| 155 | + |
| 156 | + open(FLAGS.cc_output, "w+").write(HEADER + """ |
| 157 | +namespace swift_xla { |
| 158 | +namespace ir { |
| 159 | +namespace ops { |
| 160 | +namespace { |
| 161 | +""" + ("".join(node_type_define(op) for op in op_list)) + """ |
| 162 | +} // namespace |
| 163 | +} // namespace ops |
| 164 | +} // namespace ir |
| 165 | +} // namespace swift_xla |
| 166 | +""" + "".join(c_function_define(op) for op in op_list)) |
| 167 | + |
| 168 | +if __name__ == "__main__": |
| 169 | + app.run(main) |
0 commit comments