
Commit 72e8e03

Generate log_softmax_backward from a spec.
1 parent 897bac9 commit 72e8e03

7 files changed: +251 -7 lines changed

Sources/CX10/BUILD

Lines changed: 5 additions & 1 deletion
@@ -10,7 +10,11 @@ cc_library(
 
 cc_library(
     name = "xla_tensor_wrapper",
-    srcs = ["xla_tensor_wrapper.cc"],
+    srcs = [
+        "xla_tensor_wrapper.cc",
+        "xla_tensor_ops_wrapper.cc",
+        "xla_tensor_ops_wrapper_generated.cc.inc",
+    ],
     hdrs = ["xla_tensor_wrapper.h"],
     deps = [
         ":device_wrapper",

Sources/CX10/xla_tensor_ops_wrapper.cc

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+// Copyright 2020 TensorFlow Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xla_tensor_wrapper.h"
+#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
+
+#include "xla_tensor_ops_wrapper_generated.cc.inc"

Sources/CX10/xla_tensor_ops_wrapper_generated.cc.inc

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+// Autogenerated by codegen.py. Do not modify.
+
+namespace swift_xla {
+namespace ir {
+namespace ops {
+namespace {
+
+class LogSoftmaxBackward : public Node {
+ public:
+  LogSoftmaxBackward(const Value& grad_output, const Value& output, xla::int64 dim)
+      : Node(ir::OpKind(at::aten::_log_softmax_backward_data),
+             {grad_output, output}, grad_output.shape(),
+             /*num_outputs=*/1, xla::util::MHash(dim)),
+        dim_(dim) {}
+
+  NodePtr Clone(OpList operands) const override {
+    return MakeNode<LogSoftmaxBackward>(
+        operands.at(0), operands.at(1), dim_);
+  }
+
+  XlaOpVector Lower(LoweringContext* loctx) const override {
+    xla::XlaOp result = BuildLogSoftmaxGrad(
+        loctx->GetOutputOp(operand(0)), loctx->GetOutputOp(operand(1)), dim_);
+    return ReturnOp(result, loctx);
+  }
+
+  std::string ToString() const override {
+    std::stringstream ss;
+    ss << Node::ToString() << ", dim=" << dim_;
+    return ss.str();
+  }
+
+ private:
+  xla::int64 dim_;
+};
+
+} // namespace
+} // namespace ops
+} // namespace ir
+} // namespace swift_xla
+
+OpaqueXLATensor* XLATensor_log_softmax_backward(OpaqueXLATensor* grad_output, OpaqueXLATensor* output, int64_t dim) {
+  auto grad_output_ir_value = grad_output->GetIrValue();
+  auto output_ir_value = output->GetIrValue();
+  return new swift_xla::XLATensor(grad_output->CreateFrom(
+      swift_xla::ir::MakeNode<swift_xla::ir::ops::LogSoftmaxBackward>(grad_output_ir_value, output_ir_value, swift_xla::XlaHelpers::GetCanonicalDimensionIndex(dim, grad_output_ir_value.shape().rank()))));
+}
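
The generated C entry point above routes dim through swift_xla::XlaHelpers::GetCanonicalDimensionIndex before constructing the IR node, which is what the "canonicalize dim grad_output" extra in the op spec (further down in this diff) asks for. As a rough illustration, dimension canonicalization conventionally wraps negative indices into [0, rank); the Python sketch below assumes that convention, since the C++ helper itself is not part of this diff:

def get_canonical_dimension_index(dim, rank):
  """Sketch of the usual canonicalization convention (assumed; the actual
  XlaHelpers::GetCanonicalDimensionIndex is not shown in this commit)."""
  if not -rank <= dim < rank:
    raise ValueError(f"dimension {dim} out of range for rank {rank}")
  return dim + rank if dim < 0 else dim

# For a rank-2 grad_output, dim=-1 refers to the last axis:
assert get_canonical_dimension_index(-1, 2) == 1
assert get_canonical_dimension_index(1, 2) == 1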

Sources/CX10/xla_tensor_wrapper.cc

Lines changed: 0 additions & 6 deletions
@@ -495,12 +495,6 @@ OpaqueXLATensor* XLATensor_log1p(OpaqueXLATensor* a) {
 OpaqueXLATensor* XLATensor_log_softmax(OpaqueXLATensor* a, int64_t dim) {
   return new XLATensor(XLATensor::log_softmax(*a, dim, absl::nullopt));
 }
-OpaqueXLATensor* XLATensor_log_softmax_backward(OpaqueXLATensor* grad_output,
-                                                OpaqueXLATensor* output,
-                                                int64_t dim) {
-  return new XLATensor(
-      XLATensor::log_softmax_backward(*grad_output, *output, dim));
-}
 OpaqueXLATensor* XLATensor_logical_cast(OpaqueXLATensor* input,
                                         enum XLATensorScalarType dest_type) {
   return new XLATensor(

codegen.py

Lines changed: 169 additions & 0 deletions

@@ -0,0 +1,169 @@
+# Lint as: python3
+from absl import app
+from absl import flags
+import re
+
+import yaml
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("def_file", None, "path to list of ops")
+flags.DEFINE_string("swift_out", None, "path for the generated swift file")
+flags.DEFINE_string("cc_output", None, "path for the generated cc file")
+
+HEADER = """// Autogenerated by codegen.py. Do not modify.
+"""
+
+def node_type_define(op):
+  tensor_args = []
+  attr_args = []
+  for arg in op["args"]:
+    if arg[1] == "Tensor": tensor_args.append(arg)
+    else: attr_args.append(arg)
+  def format_pretty_print(arg):
+    return f" << \", {arg[0]}=\" << {arg[0]}_"
+  def format_ctor_arg(arg):
+    name, stype = arg
+    if stype == "Tensor": return f"const Value& {name}"
+    if stype == "Int64": return f"xla::int64 {name}"
+    raise f"Problem: no such type: {stype}"
+  lower_arg_i = 0
+  def format_lower_arg(arg):
+    nonlocal lower_arg_i
+    name, stype = arg
+    if stype == "Tensor":
+      i = lower_arg_i
+      lower_arg_i += 1
+      return "loctx->GetOutputOp(operand(" + str(i) + "))"
+    if stype == "Int64": return f"{name}_"
+    raise f"Problem: no such type: {stype}"
+  clone_arg_i = 0
+  def format_clone_arg(arg):
+    nonlocal clone_arg_i
+    name, stype = arg
+    if stype == "Tensor":
+      i = clone_arg_i
+      clone_arg_i += 1
+      return "operands.at(" + str(i) + ")"
+    if stype == "Int64": return f"{name}_"
+    raise f"Problem: no such type: {stype}"
+  def format_attr_define(arg):
+    name, stype = arg
+    if stype == "Int64": return f"  xla::int64 {name}_;\n"
+    raise f"Problem: no such type: {stype}"
+  def format_attr_init(arg):
+    return f",\n        {arg[0]}_({arg[0]})"
+  shape_fn = f"""{{}}\n#error no shape function for {op["op_node_name"]}\n"""
+  def resolve_shape_fn(shape_fn):
+    for arg in tensor_args:
+      if arg[0] == shape_fn: return f"{arg[0]}.shape()"
+    return f"""{shape_fn}({", ".join(arg[0] for arg in op["args"])})"""
+  if op["shape_fn"]:
+    shape_fn = resolve_shape_fn(op["shape_fn"])
+  num_outputs = 1
+  return f"""
+class {op["op_node_name"]} : public Node {{
+ public:
+  {op["op_node_name"]}({", ".join(format_ctor_arg(arg) for arg in op["args"])})
+      : Node(ir::OpKind({op["x10_enum"]}),
+             {{{", ".join(arg[0] for arg in tensor_args)}}}, {shape_fn},
+             /*num_outputs=*/{str(num_outputs)}, xla::util::MHash({", ".join(arg[0] for arg in attr_args)})){
+  "".join(format_attr_init(arg) for arg in attr_args)
+} {{}}
+
+  NodePtr Clone(OpList operands) const override {{
+    return MakeNode<{op["op_node_name"]}>(
+        {", ".join(format_clone_arg(arg) for arg in op["args"])});
+  }}
+
+  XlaOpVector Lower(LoweringContext* loctx) const override {{
+    xla::XlaOp result = {op["lower_fn"]}(
+        {", ".join(format_lower_arg(arg) for arg in op["args"])});
+    return ReturnOp(result, loctx);
+  }}
+
+  std::string ToString() const override {{
+    std::stringstream ss;
+    ss << Node::ToString(){"".join(format_pretty_print(arg) for arg in attr_args)};
+    return ss.str();
+  }}
+
+ private:
+{"".join(format_attr_define(arg) for arg in attr_args)}}};
+"""
+
+def c_function_define(op):
+  def format_arg_def(arg):
+    name, stype = arg
+    if stype == "Tensor": return "OpaqueXLATensor* " + name
+    if stype == "Int64": return "int64_t " + name
+    raise "problem unknown type: " + stype
+  def format_arg_ref(arg):
+    name, stype = arg
+    if stype == "Tensor": return name + "_ir_value"
+    for extra in op["extras"]:
+      if extra[0] == "canonicalize" and extra[1] == name:
+        return f"swift_xla::XlaHelpers::GetCanonicalDimensionIndex({name}, {extra[2]}_ir_value.shape().rank())"
+    return name
+  def unpack_arg(arg):
+    name, stype = arg
+    if stype == "Tensor": return f"  auto {name}_ir_value = {name}->GetIrValue();\n"
+    return ""
+  args = op["args"]
+  first_tensor = args[0][0]
+  return f"""
+OpaqueXLATensor* XLATensor_{op["c_name"]}({", ".join(format_arg_def(arg) for arg in op["args"])}) {{
+{"".join(unpack_arg(arg) for arg in op["args"])}  return new swift_xla::XLATensor({first_tensor}->CreateFrom(
+      swift_xla::ir::MakeNode<swift_xla::ir::ops::{op["op_node_name"]}>({", ".join(format_arg_ref(arg) for arg in op["args"])})));
+}}
+"""
+
+def snake_to_camel(name):
+  return "".join(map(lambda x: x.capitalize(),name.split("_")))
+
+def canonicalize_op(op):
+  tokens = re.findall("(\w+|[\(\),:]|->)", op["def"])
+  op["c_name"] = tokens[0]
+  def expect(cond):
+    if not cond: raise ValueError(f"""invalid format: {repr(op["def"])}""")
+  expect(tokens[1] == '(')
+  def isWord(idx):
+    return re.match("\w+", tokens[idx]) != None
+  i = 2
+  args = []
+  if tokens[i] != ')':
+    while True:
+      expect(tokens[i + 1] == ':')
+      expect(isWord(i) and isWord(i + 2))
+      args.append((tokens[i], tokens[i + 2]))
+      i += 3
+      if tokens[i] == ')': break
+      expect(tokens[i] == ',')
+      i += 1
+  i += 1
+
+  op["args"] = args
+  if "op_node_name" not in op: op["op_node_name"] = snake_to_camel(op["c_name"])
+  op["extras"] = [a.split() for a in op["extras"]]
+  del op["def"]
+
+def main(argv):
+  if len(argv) > 1:
+    raise app.UsageError("Too many command-line arguments.")
+  op_list = yaml.full_load(open(FLAGS.def_file).read())
+  for op in op_list: canonicalize_op(op)
+
+  open(FLAGS.cc_output, "w+").write(HEADER + """
+namespace swift_xla {
+namespace ir {
+namespace ops {
+namespace {
+""" + ("".join(node_type_define(op) for op in op_list)) + """
+} // namespace
+} // namespace ops
+} // namespace ir
+} // namespace swift_xla
+""" + "".join(c_function_define(op) for op in op_list))
+
+if __name__ == "__main__":
+  app.run(main)
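
codegen.py reads the YAML op list named by --def_file and writes the C++ node classes plus C wrappers to --cc_output (a --swift_out flag is defined but not yet consumed by main here). A minimal sketch of driving it, with a placeholder spec path since the real spec filename and build wiring are not shown in this commit:

import subprocess

# Hypothetical invocation; "ops.yaml" is a placeholder for the spec file,
# while the output name matches the generated file included by
# xla_tensor_ops_wrapper.cc above.
subprocess.run(
    [
        "python3", "codegen.py",
        "--def_file=ops.yaml",
        "--cc_output=xla_tensor_ops_wrapper_generated.cc.inc",
    ],
    check=True,
)
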
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+- def: "log_softmax_backward(grad_output: Tensor, output: Tensor, dim: Int64) -> Tensor"
+  extras: ["canonicalize dim grad_output"]
+  x10_enum: at::aten::_log_softmax_backward_data
+  shape_fn: grad_output
+  lower_fn: BuildLogSoftmaxGrad
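
For reference, canonicalize_op in codegen.py turns the spec entry above (this is the file --def_file points at; its path is not preserved here) into the intermediate form the templates consume. A small sketch of that step; the dict literal mirrors what yaml.full_load produces for this entry, and only the tokenization is lifted from the script:

import re

# The single op spec from the YAML file above, as yaml.full_load returns it.
op = {
    "def": "log_softmax_backward(grad_output: Tensor, output: Tensor, dim: Int64) -> Tensor",
    "extras": ["canonicalize dim grad_output"],
    "x10_enum": "at::aten::_log_softmax_backward_data",
    "shape_fn": "grad_output",
    "lower_fn": "BuildLogSoftmaxGrad",
}

# Same tokenization canonicalize_op applies to the "def" string.
tokens = re.findall(r"(\w+|[\(\),:]|->)", op["def"])
print(tokens[0])  # log_softmax_backward  -> becomes op["c_name"]
# Indices 2, 6, 10 are specific to this three-argument entry.
print([(tokens[i], tokens[i + 2]) for i in (2, 6, 10)])
# [('grad_output', 'Tensor'), ('output', 'Tensor'), ('dim', 'Int64')]  -> op["args"]

# canonicalize_op also derives op_node_name ("LogSoftmaxBackward" via
# snake_to_camel) and splits extras into [["canonicalize", "dim", "grad_output"]],
# which is what makes the generated wrapper call GetCanonicalDimensionIndex on dim.
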

Sources/x10/xla_tensor/tensor.h

Lines changed: 2 additions & 0 deletions
@@ -1424,6 +1424,7 @@ class XLATensor {
 
   XLATensor CopyTensorToDevice(const Device& device);
 
+ public:
   // Create a new XLA tensor with the same metadata of the input tensor (with
   // possible overrides), and the new IR value.
   XLATensor CreateFrom(ir::Value ir_value) const;
@@ -1435,6 +1436,7 @@ class XLATensor {
       c10::optional<at::ScalarType> logical_element_type_opt) const;
   XLATensor CreateFrom(ir::Value ir_value, const Device& device,
                        at::ScalarType logical_element_type) const;
+ private:
 
   // We build an XLA graph accumulating XLA operations, but at a given point we
   // need to force a rendering, otherwise the graph can grow without control.
