This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Generate log_softmax_backward from a spec. #1070

Merged 1 commit on Sep 2, 2020
6 changes: 5 additions & 1 deletion Sources/CX10/BUILD
@@ -10,7 +10,11 @@ cc_library(

cc_library(
name = "xla_tensor_wrapper",
srcs = ["xla_tensor_wrapper.cc"],
srcs = [
"xla_tensor_wrapper.cc",
"xla_tensor_ops_wrapper.cc",
"xla_tensor_ops_wrapper_generated.cc.inc",
],
hdrs = ["xla_tensor_wrapper.h"],
deps = [
":device_wrapper",
30 changes: 30 additions & 0 deletions Sources/CX10/xla_tensor_ops_wrapper.cc
@@ -0,0 +1,30 @@
// Copyright 2020 TensorFlow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(_WIN32)
#define XLA_API __declspec(dllexport)
#else
#define XLA_API __attribute__((__visibility__("default")))
#endif

#include "xla_tensor_wrapper.h"

#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/helpers.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"

#include "xla_tensor_ops_wrapper_generated.cc.inc"
47 changes: 47 additions & 0 deletions Sources/CX10/xla_tensor_ops_wrapper_generated.cc.inc
@@ -0,0 +1,47 @@
// Autogenerated by codegen.py. Do not modify.

namespace swift_xla {
namespace ir {
namespace ops {
namespace {

class LogSoftmaxBackward : public Node {
public:
LogSoftmaxBackward(const Value& grad_output, const Value& output, xla::int64 dim)
: Node(ir::OpKind(at::aten::_log_softmax_backward_data),
{grad_output, output}, grad_output.shape(),
/*num_outputs=*/1, xla::util::MHash(dim)),
dim_(dim) {}

NodePtr Clone(OpList operands) const override {
return MakeNode<LogSoftmaxBackward>(
operands.at(0), operands.at(1), dim_);
}

XlaOpVector Lower(LoweringContext* loctx) const override {
xla::XlaOp result = BuildLogSoftmaxGrad(
loctx->GetOutputOp(operand(0)), loctx->GetOutputOp(operand(1)), dim_);
return ReturnOp(result, loctx);
}

std::string ToString() const override {
std::stringstream ss;
ss << Node::ToString() << ", dim=" << dim_;
return ss.str();
}

private:
xla::int64 dim_;
};

} // namespace
} // namespace ops
} // namespace ir
} // namespace swift_xla

OpaqueXLATensor* XLATensor_log_softmax_backward(OpaqueXLATensor* grad_output, OpaqueXLATensor* output, int64_t dim) {
auto grad_output_ir_value = grad_output->GetIrValue();
auto output_ir_value = output->GetIrValue();
return new swift_xla::XLATensor(grad_output->CreateFrom(
swift_xla::ir::MakeNode<swift_xla::ir::ops::LogSoftmaxBackward>(grad_output_ir_value, output_ir_value, swift_xla::XlaHelpers::GetCanonicalDimensionIndex(dim, grad_output_ir_value.shape().rank()))));
}
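
Note that the generated wrapper routes dim through swift_xla::XlaHelpers::GetCanonicalDimensionIndex(dim, rank) before building the node. Below is a rough Python analogue of that canonicalization, offered as an assumption based on the helper's name and the usual convention of wrapping negative dimension indices; the real helper may differ in its error handling.

# Hypothetical sketch of dimension canonicalization, not the actual C++ helper.
def canonical_dimension_index(dim: int, rank: int) -> int:
    wrapped = dim + rank if dim < 0 else dim
    if not 0 <= wrapped < rank:
        raise ValueError(f"dimension {dim} out of range for rank {rank}")
    return wrapped

assert canonical_dimension_index(-1, 3) == 2  # last axis of a rank-3 tensor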
6 changes: 0 additions & 6 deletions Sources/CX10/xla_tensor_wrapper.cc
@@ -495,12 +495,6 @@ OpaqueXLATensor* XLATensor_log1p(OpaqueXLATensor* a) {
OpaqueXLATensor* XLATensor_log_softmax(OpaqueXLATensor* a, int64_t dim) {
return new XLATensor(XLATensor::log_softmax(*a, dim, absl::nullopt));
}
OpaqueXLATensor* XLATensor_log_softmax_backward(OpaqueXLATensor* grad_output,
OpaqueXLATensor* output,
int64_t dim) {
return new XLATensor(
XLATensor::log_softmax_backward(*grad_output, *output, dim));
}
OpaqueXLATensor* XLATensor_logical_cast(OpaqueXLATensor* input,
enum XLATensorScalarType dest_type) {
return new XLATensor(
169 changes: 169 additions & 0 deletions Sources/x10/swift_bindings/generate_ops.py
@@ -0,0 +1,169 @@
# Lint as: python3
from absl import app
from absl import flags
import re

import yaml

FLAGS = flags.FLAGS

flags.DEFINE_string("def_file", None, "path to list of ops")
flags.DEFINE_string("swift_out", None, "path for the generated swift file")
flags.DEFINE_string("cc_output", None, "path for the generated cc file")

HEADER = """// Autogenerated by codegen.py. Do not modify.
"""

def node_type_define(op):
tensor_args = []
attr_args = []
for arg in op["args"]:
if arg[1] == "Tensor": tensor_args.append(arg)
else: attr_args.append(arg)
def format_pretty_print(arg):
return f" << \", {arg[0]}=\" << {arg[0]}_"
def format_ctor_arg(arg):
name, stype = arg
if stype == "Tensor": return f"const Value& {name}"
if stype == "Int64": return f"xla::int64 {name}"
raise f"Problem: no such type: {stype}"
lower_arg_i = 0
def format_lower_arg(arg):
nonlocal lower_arg_i
name, stype = arg
if stype == "Tensor":
i = lower_arg_i
lower_arg_i += 1
return "loctx->GetOutputOp(operand(" + str(i) + "))"
if stype == "Int64": return f"{name}_"
raise f"Problem: no such type: {stype}"
clone_arg_i = 0
def format_clone_arg(arg):
nonlocal clone_arg_i
name, stype = arg
if stype == "Tensor":
i = clone_arg_i
clone_arg_i += 1
return "operands.at(" + str(i) + ")"
if stype == "Int64": return f"{name}_"
raise f"Problem: no such type: {stype}"
def format_attr_define(arg):
name, stype = arg
if stype == "Int64": return f" xla::int64 {name}_;\n"
raise f"Problem: no such type: {stype}"
def format_attr_init(arg):
return f",\n {arg[0]}_({arg[0]})"
shape_fn = f"""{{}}\n#error no shape function for {op["op_node_name"]}\n"""
def resolve_shape_fn(shape_fn):
for arg in tensor_args:
if arg[0] == shape_fn: return f"{arg[0]}.shape()"
return f"""{shape_fn}({", ".join(arg[0] for arg in op["args"])})"""
if op["shape_fn"]:
shape_fn = resolve_shape_fn(op["shape_fn"])
num_outputs = 1
return f"""
class {op["op_node_name"]} : public Node {{
public:
{op["op_node_name"]}({", ".join(format_ctor_arg(arg) for arg in op["args"])})
: Node(ir::OpKind({op["x10_enum"]}),
{{{", ".join(arg[0] for arg in tensor_args)}}}, {shape_fn},
/*num_outputs=*/{str(num_outputs)}, xla::util::MHash({", ".join(arg[0] for arg in attr_args)})){
"".join(format_attr_init(arg) for arg in attr_args)
} {{}}

NodePtr Clone(OpList operands) const override {{
return MakeNode<{op["op_node_name"]}>(
{", ".join(format_clone_arg(arg) for arg in op["args"])});
}}

XlaOpVector Lower(LoweringContext* loctx) const override {{
xla::XlaOp result = {op["lower_fn"]}(
{", ".join(format_lower_arg(arg) for arg in op["args"])});
return ReturnOp(result, loctx);
}}

std::string ToString() const override {{
std::stringstream ss;
ss << Node::ToString(){"".join(format_pretty_print(arg) for arg in attr_args)};
return ss.str();
}}

private:
{"".join(format_attr_define(arg) for arg in attr_args)}}};
"""

def c_function_define(op):
def format_arg_def(arg):
name, stype = arg
if stype == "Tensor": return "OpaqueXLATensor* " + name
if stype == "Int64": return "int64_t " + name
raise "problem unknown type: " + stype
def format_arg_ref(arg):
name, stype = arg
if stype == "Tensor": return name + "_ir_value"
for extra in op["extras"]:
if extra[0] == "canonicalize" and extra[1] == name:
return f"swift_xla::XlaHelpers::GetCanonicalDimensionIndex({name}, {extra[2]}_ir_value.shape().rank())"
return name
def unpack_arg(arg):
name, stype = arg
if stype == "Tensor": return f" auto {name}_ir_value = {name}->GetIrValue();\n"
return ""
args = op["args"]
first_tensor = args[0][0]
return f"""
OpaqueXLATensor* XLATensor_{op["c_name"]}({", ".join(format_arg_def(arg) for arg in op["args"])}) {{
{"".join(unpack_arg(arg) for arg in op["args"])} return new swift_xla::XLATensor({first_tensor}->CreateFrom(
swift_xla::ir::MakeNode<swift_xla::ir::ops::{op["op_node_name"]}>({", ".join(format_arg_ref(arg) for arg in op["args"])})));
}}
"""

def snake_to_camel(name):
return "".join(map(lambda x: x.capitalize(),name.split("_")))

def canonicalize_op(op):
tokens = re.findall("(\w+|[\(\),:]|->)", op["def"])
op["c_name"] = tokens[0]
def expect(cond):
if not cond: raise ValueError(f"""invalid format: {repr(op["def"])}""")
expect(tokens[1] == '(')
def isWord(idx):
return re.match("\w+", tokens[idx]) != None
i = 2
args = []
if tokens[i] != ')':
while True:
expect(tokens[i + 1] == ':')
expect(isWord(i) and isWord(i + 2))
args.append((tokens[i], tokens[i + 2]))
i += 3
if tokens[i] == ')': break
expect(tokens[i] == ',')
i += 1
i += 1

op["args"] = args
if "op_node_name" not in op: op["op_node_name"] = snake_to_camel(op["c_name"])
op["extras"] = [a.split() for a in op["extras"]]
del op["def"]

def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
op_list = yaml.full_load(open(FLAGS.def_file).read())
for op in op_list: canonicalize_op(op)

open(FLAGS.cc_output, "w+").write(HEADER + """
namespace swift_xla {
namespace ir {
namespace ops {
namespace {
""" + ("".join(node_type_define(op) for op in op_list)) + """
} // namespace
} // namespace ops
} // namespace ir
} // namespace swift_xla
""" + "".join(c_function_define(op) for op in op_list))

if __name__ == "__main__":
app.run(main)
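
As an illustration of the parsing step above, the following standalone Python sketch (not part of the PR) applies the same token regex that canonicalize_op uses to the log_softmax_backward spec line and recovers the (name, type) argument pairs:

# Standalone sketch of canonicalize_op's argument parsing (illustration only).
import re

op_def = "log_softmax_backward(grad_output: Tensor, output: Tensor, dim: Int64) -> Tensor"
tokens = re.findall(r"(\w+|[\(\),:]|->)", op_def)

c_name = tokens[0]  # "log_softmax_backward"
args = []
i = 2
while tokens[i] != ")":
    # Each argument is three tokens: name, ":", type.
    args.append((tokens[i], tokens[i + 2]))
    i += 3
    if tokens[i] == ",":
        i += 1

print(c_name, args)
# log_softmax_backward [('grad_output', 'Tensor'), ('output', 'Tensor'), ('dim', 'Int64')]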
5 changes: 5 additions & 0 deletions Sources/x10/swift_bindings/ops_list.txt
@@ -0,0 +1,5 @@
- def: "log_softmax_backward(grad_output: Tensor, output: Tensor, dim: Int64) -> Tensor"
  extras: ["canonicalize dim grad_output"]
  x10_enum: at::aten::_log_softmax_backward_data
  shape_fn: grad_output
  lower_fn: BuildLogSoftmaxGrad
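
Each field above maps onto a piece of the generated code: def determines the C function signature and the node constructor arguments, x10_enum becomes the ir::OpKind, shape_fn names a tensor argument whose shape is reused (or a shape-computation function), lower_fn is the builder called in Lower(), and the canonicalize extra routes dim through GetCanonicalDimensionIndex. A quick sketch of how generate_ops.py reads this file (assuming PyYAML is installed and the path is relative to the repository root):

import yaml

# Same loader call generate_ops.py uses; prints the raw spec fields for the op.
ops = yaml.full_load(open("Sources/x10/swift_bindings/ops_list.txt").read())
op = ops[0]
print(op["def"])       # the signature parsed by canonicalize_op
print(op["x10_enum"])  # at::aten::_log_softmax_backward_data
print(op["shape_fn"])  # grad_output
print(op["lower_fn"])  # BuildLogSoftmaxGrad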
2 changes: 2 additions & 0 deletions Sources/x10/xla_tensor/tensor.h
@@ -1424,6 +1424,7 @@ class XLATensor {

XLATensor CopyTensorToDevice(const Device& device);

public:
// Create a new XLA tensor with the same metadata of the input tensor (with
// possible overrides), and the new IR value.
XLATensor CreateFrom(ir::Value ir_value) const;
@@ -1435,6 +1436,7 @@
c10::optional<at::ScalarType> logical_element_type_opt) const;
XLATensor CreateFrom(ir::Value ir_value, const Device& device,
at::ScalarType logical_element_type) const;
private:

// We build an XLA graph accumulating XLA operations, but at a given point we
// need to force a rendering, otherwise the graph can grow without control.