This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Generate cumsum and cumprod from a spec. #1071

Merged, 1 commit, Sep 3, 2020

67 changes: 67 additions & 0 deletions Sources/CX10/xla_tensor_ops_wrapper.cc
@@ -26,5 +26,72 @@
#include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/tensor_util.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/convert_ops.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"

namespace at {
xla::hash_t Hash(const c10::optional<at::ScalarType>& dtype) {
return xla::util::Hash(swift_xla::OptionalOr<int>(dtype, -1));
}
}
namespace swift_xla {
void OpFieldToString(std::ostream& stream, const char* field_name, const c10::optional<at::ScalarType>& dtype) {
if (dtype) stream << ", " << field_name << "=" << *dtype;
}
void OpFieldToString(std::ostream& stream, const char* field_name, bool value) {
stream << ", " << field_name << "=" << value;
}
void OpFieldToString(std::ostream& stream, const char* field_name, xla::int64 value) {
stream << ", " << field_name << "=" << value;
}
} // namespace swift_xla

namespace swift_xla {
namespace ir {
namespace ops {
namespace {

xla::XlaOp LowerCumSum(xla::XlaOp input, xla::int64 dim,
c10::optional<at::ScalarType> dtype, bool exclusive,
bool reverse) {
xla::XlaOp casted_input = CastToScalarType(input, dtype);
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(casted_input);
xla::XlaOp init = XlaHelpers::ScalarValue<float>(
0, input_shape.element_type(), casted_input.builder());
xla::XlaComputation reducer =
XlaHelpers::CreateAddComputation(input_shape.element_type());
return BuildCumulativeComputation(casted_input, dim, reducer, init, exclusive,
reverse);
}

xla::XlaOp LowerCumProd(xla::XlaOp input, xla::int64 dim,
c10::optional<at::ScalarType> dtype, bool exclusive,
bool reverse) {
xla::XlaOp casted_input = CastToScalarType(input, dtype);
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(casted_input);
xla::XlaOp init =
xla::One(casted_input.builder(), input_shape.element_type());
xla::XlaComputation reducer =
XlaHelpers::CreateMulComputation(input_shape.element_type());
return BuildCumulativeComputation(casted_input, dim, reducer, init, exclusive,
reverse);
}

xla::Shape CumOpShapeFn(const Value& input, xla::int64 dim,
c10::optional<at::ScalarType> dtype, bool exclusive,
bool reverse) {
if (dtype) {
return xla::ShapeUtil::ChangeElementType(
input.shape(), MakeXlaPrimitiveType(*dtype, /*device=*/nullptr));
}
return input.shape();
}

} // namespace
} // namespace ops
} // namespace ir
} // namespace swift_xla

#include "xla_tensor_ops_wrapper_generated.cc.inc"
95 changes: 93 additions & 2 deletions Sources/CX10/xla_tensor_ops_wrapper_generated.cc.inc
@@ -1,10 +1,88 @@
// Autogenerated by codegen.py. Do not modify.

namespace swift_xla {
namespace ir {
namespace ops {
namespace {

class Cumprod : public Node {
public:
Cumprod(const Value& input, xla::int64 dim, c10::optional<at::ScalarType> dtype, bool exclusive, bool reverse)
: Node(ir::OpKind(at::aten::cumprod),
{input}, CumOpShapeFn(input, dim, dtype, exclusive, reverse),
/*num_outputs=*/1, xla::util::MHash(dim, dtype, exclusive, reverse)),
dim_(dim),
dtype_(dtype),
exclusive_(exclusive),
reverse_(reverse) {}

NodePtr Clone(OpList operands) const override {
return MakeNode<Cumprod>(
operands.at(0), dim_, dtype_, exclusive_, reverse_);
}

XlaOpVector Lower(LoweringContext* loctx) const override {
xla::XlaOp result = LowerCumProd(
loctx->GetOutputOp(operand(0)), dim_, dtype_, exclusive_, reverse_);
return ReturnOp(result, loctx);
}

std::string ToString() const override {
std::stringstream ss;
ss << Node::ToString();
OpFieldToString(ss, "dim", dim_);
OpFieldToString(ss, "dtype", dtype_);
OpFieldToString(ss, "exclusive", exclusive_);
OpFieldToString(ss, "reverse", reverse_);
return ss.str();
}

private:
xla::int64 dim_;
c10::optional<at::ScalarType> dtype_;
bool exclusive_;
bool reverse_;
};

class Cumsum : public Node {
public:
Cumsum(const Value& input, xla::int64 dim, c10::optional<at::ScalarType> dtype, bool exclusive, bool reverse)
: Node(ir::OpKind(at::aten::cumsum),
{input}, CumOpShapeFn(input, dim, dtype, exclusive, reverse),
/*num_outputs=*/1, xla::util::MHash(dim, dtype, exclusive, reverse)),
dim_(dim),
dtype_(dtype),
exclusive_(exclusive),
reverse_(reverse) {}

NodePtr Clone(OpList operands) const override {
return MakeNode<Cumsum>(
operands.at(0), dim_, dtype_, exclusive_, reverse_);
}

XlaOpVector Lower(LoweringContext* loctx) const override {
xla::XlaOp result = LowerCumSum(
loctx->GetOutputOp(operand(0)), dim_, dtype_, exclusive_, reverse_);
return ReturnOp(result, loctx);
}

std::string ToString() const override {
std::stringstream ss;
ss << Node::ToString();
OpFieldToString(ss, "dim", dim_);
OpFieldToString(ss, "dtype", dtype_);
OpFieldToString(ss, "exclusive", exclusive_);
OpFieldToString(ss, "reverse", reverse_);
return ss.str();
}

private:
xla::int64 dim_;
c10::optional<at::ScalarType> dtype_;
bool exclusive_;
bool reverse_;
};

class LogSoftmaxBackward : public Node {
public:
LogSoftmaxBackward(const Value& grad_output, const Value& output, xla::int64 dim)
@@ -26,7 +104,8 @@ class LogSoftmaxBackward : public Node {

std::string ToString() const override {
std::stringstream ss;
ss << Node::ToString() << ", dim=" << dim_;
ss << Node::ToString();
OpFieldToString(ss, "dim", dim_);
return ss.str();
}

@@ -39,6 +118,18 @@
} // namespace ir
} // namespace swift_xla

OpaqueXLATensor* XLATensor_cumprod(OpaqueXLATensor* input, int64_t dim, Optional_XLAScalarType dtype, bool exclusive, bool reverse) {
auto input_ir_value = input->GetIrValue();
return new swift_xla::XLATensor(input->CreateFrom(
swift_xla::ir::MakeNode<swift_xla::ir::ops::Cumprod>(input_ir_value, swift_xla::XlaHelpers::GetCanonicalDimensionIndex(dim, input_ir_value.shape().rank()), dtype.value(), exclusive, reverse)));
}

OpaqueXLATensor* XLATensor_cumsum(OpaqueXLATensor* input, int64_t dim, Optional_XLAScalarType dtype, bool exclusive, bool reverse) {
auto input_ir_value = input->GetIrValue();
return new swift_xla::XLATensor(input->CreateFrom(
swift_xla::ir::MakeNode<swift_xla::ir::ops::Cumsum>(input_ir_value, swift_xla::XlaHelpers::GetCanonicalDimensionIndex(dim, input_ir_value.shape().rank()), dtype.value(), exclusive, reverse)));
}

OpaqueXLATensor* XLATensor_log_softmax_backward(OpaqueXLATensor* grad_output, OpaqueXLATensor* output, int64_t dim) {
auto grad_output_ir_value = grad_output->GetIrValue();
auto output_ir_value = output->GetIrValue();
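
The generated XLATensor_cumprod and XLATensor_cumsum entry points canonicalize `dim` before building the node. GetCanonicalDimensionIndex is assumed here to perform the usual negative-dimension wrap with a bounds check; the sketch below only illustrates that assumed behavior and is not part of the PR.

def canonical_dim(dim, rank):
    # Assumed behavior: map a possibly negative dim into [0, rank).
    canonical = dim + rank if dim < 0 else dim
    if not 0 <= canonical < rank:
        raise ValueError(f"dimension {dim} out of range for rank {rank}")
    return canonical

# canonical_dim(-1, 4) -> 3; canonical_dim(2, 4) -> 2
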
12 changes: 0 additions & 12 deletions Sources/CX10/xla_tensor_wrapper.cc
@@ -294,18 +294,6 @@ OpaqueXLATensor* XLATensor_acos(OpaqueXLATensor* a) {
OpaqueXLATensor* XLATensor_acosh(OpaqueXLATensor* a) {
return new XLATensor(XLATensor::acosh(*a));
}
OpaqueXLATensor* XLATensor_cumprod(OpaqueXLATensor* a, int64_t dim,
Optional_XLAScalarType dtype, bool exclusive,
bool reverse) {
return new XLATensor(
XLATensor::cumprod(*a, dim, dtype.value(), exclusive, reverse));
}
OpaqueXLATensor* XLATensor_cumsum(OpaqueXLATensor* a, int64_t dim,
Optional_XLAScalarType dtype, bool exclusive,
bool reverse) {
return new XLATensor(
XLATensor::cumsum(*a, dim, dtype.value(), exclusive, reverse));
}
OpaqueXLATensor* XLATensor_add(OpaqueXLATensor* a, OpaqueXLATensor* b) {
return new XLATensor(XLATensor::add(*a, *b));
}
28 changes: 16 additions & 12 deletions Sources/x10/swift_bindings/generate_ops.py
@@ -8,7 +8,6 @@
FLAGS = flags.FLAGS

flags.DEFINE_string("def_file", None, "path to list of ops")
flags.DEFINE_string("swift_out", None, "path for the generated swift file")
flags.DEFINE_string("cc_output", None, "path for the generated cc file")

HEADER = """// Autogenerated by codegen.py. Do not modify.
@@ -21,12 +20,14 @@ def node_type_define(op):
if arg[1] == "Tensor": tensor_args.append(arg)
else: attr_args.append(arg)
def format_pretty_print(arg):
return f" << \", {arg[0]}=\" << {arg[0]}_"
return f" OpFieldToString(ss, \"{arg[0]}\", {arg[0]}_);\n"
def format_ctor_arg(arg):
name, stype = arg
if stype == "Tensor": return f"const Value& {name}"
if stype == "Int64": return f"xla::int64 {name}"
raise f"Problem: no such type: {stype}"
if stype == "Bool": return f"bool {name}"
if stype == "ScalarType?": return f"c10::optional<at::ScalarType> {name}"
raise ValueError(f"Problem: no such type: {stype}")
lower_arg_i = 0
def format_lower_arg(arg):
nonlocal lower_arg_i
@@ -35,8 +36,7 @@ def format_lower_arg(arg):
i = lower_arg_i
lower_arg_i += 1
return "loctx->GetOutputOp(operand(" + str(i) + "))"
if stype == "Int64": return f"{name}_"
raise f"Problem: no such type: {stype}"
return f"{name}_"
clone_arg_i = 0
def format_clone_arg(arg):
nonlocal clone_arg_i
@@ -45,12 +45,13 @@ def format_clone_arg(arg):
i = clone_arg_i
clone_arg_i += 1
return "operands.at(" + str(i) + ")"
if stype == "Int64": return f"{name}_"
raise f"Problem: no such type: {stype}"
return f"{name}_"
def format_attr_define(arg):
name, stype = arg
if stype == "Int64": return f" xla::int64 {name}_;\n"
raise f"Problem: no such type: {stype}"
if stype == "Bool": return f" bool {name}_;\n"
if stype == "ScalarType?": return f" c10::optional<at::ScalarType> {name}_;\n"
raise ValueError(f"Problem: no such type: {stype}")
def format_attr_init(arg):
return f",\n {arg[0]}_({arg[0]})"
shape_fn = f"""{{}}\n#error no shape function for {op["op_node_name"]}\n"""
@@ -84,8 +85,8 @@ class {op["op_node_name"]} : public Node {{

std::string ToString() const override {{
std::stringstream ss;
ss << Node::ToString(){"".join(format_pretty_print(arg) for arg in attr_args)};
return ss.str();
ss << Node::ToString();
{"".join(format_pretty_print(arg) for arg in attr_args)} return ss.str();
}}

private:
@@ -97,13 +98,16 @@ def format_arg_def(arg):
name, stype = arg
if stype == "Tensor": return "OpaqueXLATensor* " + name
if stype == "Int64": return "int64_t " + name
raise "problem unknown type: " + stype
if stype == "Bool": return f"bool {name}"
if stype == "ScalarType?": return f"Optional_XLAScalarType {name}"
raise ValueError("problem unknown type: " + stype)
def format_arg_ref(arg):
name, stype = arg
if stype == "Tensor": return name + "_ir_value"
for extra in op["extras"]:
if extra[0] == "canonicalize" and extra[1] == name:
return f"swift_xla::XlaHelpers::GetCanonicalDimensionIndex({name}, {extra[2]}_ir_value.shape().rank())"
if stype == "ScalarType?": return f"{name}.value()"
return name
def unpack_arg(arg):
name, stype = arg
@@ -122,7 +126,7 @@ def snake_to_camel(name):
return "".join(map(lambda x: x.capitalize(),name.split("_")))

def canonicalize_op(op):
tokens = re.findall("(\w+|[\(\),:]|->)", op["def"])
tokens = re.findall("(\w+\??|[\(\),:]|->)", op["def"])
op["c_name"] = tokens[0]
def expect(cond):
if not cond: raise ValueError(f"""invalid format: {repr(op["def"])}""")
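
The widened token pattern in canonicalize_op is what lets optional types like ScalarType? survive parsing. A quick check of how the new regex tokenizes one of the added definitions (the snippet below is only a demonstration, not part of generate_ops.py):

import re

spec = "cumsum(input: Tensor, dim: Int64, dtype: ScalarType?, exclusive: Bool, reverse: Bool) -> Tensor"
tokens = re.findall(r"(\w+\??|[\(\),:]|->)", spec)
# 'ScalarType?' stays a single token because of the trailing `\??`:
# ['cumsum', '(', 'input', ':', 'Tensor', ',', 'dim', ':', 'Int64', ',',
#  'dtype', ':', 'ScalarType?', ',', 'exclusive', ':', 'Bool', ',',
#  'reverse', ':', 'Bool', ')', '->', 'Tensor']
print(tokens)
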
10 changes: 10 additions & 0 deletions Sources/x10/swift_bindings/ops_list.txt
@@ -1,3 +1,13 @@
- def: "cumprod(input: Tensor, dim: Int64, dtype: ScalarType?, exclusive: Bool, reverse: Bool) -> Tensor"
extras: ["canonicalize dim input"]
x10_enum: at::aten::cumprod
shape_fn: CumOpShapeFn
lower_fn: LowerCumProd
- def: "cumsum(input: Tensor, dim: Int64, dtype: ScalarType?, exclusive: Bool, reverse: Bool) -> Tensor"
extras: ["canonicalize dim input"]
x10_enum: at::aten::cumsum
shape_fn: CumOpShapeFn
lower_fn: LowerCumSum
- def: "log_softmax_backward(grad_output: Tensor, output: Tensor, dim: Int64) -> Tensor"
extras: ["canonicalize dim grad_output"]
x10_enum: at::aten::_log_softmax_backward_data