This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Tensor-level annotations #1064

Merged · 17 commits · Aug 20, 2020
9 changes: 9 additions & 0 deletions Sources/CX10/xla_tensor_wrapper.cc
@@ -315,6 +315,10 @@ OpaqueXLATensor* XLATensor_all(OpaqueXLATensor* input, Int64ArrayRef dimensions,
XlaHelpers::I64List(dimensions.slice()),
keep_reduced_dimensions));
}
OpaqueXLATensor* XLATensor_annotate(OpaqueXLATensor* a,
const char* annotation) {
return new XLATensor(XLATensor::annotate(*a, std::string(annotation)));
}
OpaqueXLATensor* XLATensor_any(OpaqueXLATensor* input, Int64ArrayRef dimensions,
bool keep_reduced_dimensions) {
return new XLATensor(XLATensor::any(*input,
@@ -441,6 +445,11 @@ OpaqueXLATensor* XLATensor_full(Int64ArrayRef size, XLAScalar value,
OpaqueXLATensor* XLATensor_ge(OpaqueXLATensor* x, OpaqueXLATensor* y) {
return new XLATensor(XLATensor::ge(*x, *y));
}
OpaqueString* XLATensor_get_annotations(OpaqueXLATensor* a) {
std::string ir_dag_text =
swift_xla::ir::DumpUtil::GetAnnotations({a->GetIrValue().node.get()});
return new std::string(ir_dag_text);
}
OpaqueXLATensor* XLATensor_gt(OpaqueXLATensor* x, OpaqueXLATensor* y) {
return new XLATensor(XLATensor::gt(*x, *y));
}
4 changes: 3 additions & 1 deletion Sources/CX10/xla_tensor_wrapper.h
@@ -227,6 +227,7 @@ XLA_API OpaqueXLATensor* XLATensor_add(OpaqueXLATensor* a, OpaqueXLATensor* b);
XLA_API OpaqueXLATensor* XLATensor_all(OpaqueXLATensor* input,
Int64ArrayRef dimensions,
bool keep_reduced_dimensions);
XLA_API OpaqueXLATensor* XLATensor_annotate(OpaqueXLATensor* a, const char*);
XLA_API OpaqueXLATensor* XLATensor_any(OpaqueXLATensor* input,
Int64ArrayRef dimensions,
bool keep_reduced_dimensions);
@@ -284,10 +285,12 @@ XLA_API OpaqueXLATensor*
XLATensor_full(Int64ArrayRef size, XLAScalar value, const struct CDevice device,
enum XLATensorScalarType type);
XLA_API OpaqueXLATensor* XLATensor_ge(OpaqueXLATensor* x, OpaqueXLATensor* y);
XLA_API OpaqueString* XLATensor_get_annotations(OpaqueXLATensor* a);
XLA_API OpaqueXLATensor* XLATensor_gt(OpaqueXLATensor* x, OpaqueXLATensor* y);
XLA_API OpaqueXLATensor* XLATensor_index(OpaqueXLATensor* input,
OpaqueXLATensorArrayRef indices,
int64_t start_dim);
XLA_API OpaqueString* XLATensor_ir_text(OpaqueXLATensor* a);
XLA_API OpaqueXLATensor* XLATensor_is_finite(OpaqueXLATensor* input);
XLA_API OpaqueXLATensor* XLATensor_is_inf(OpaqueXLATensor* input);
XLA_API OpaqueXLATensor* XLATensor_is_nan(OpaqueXLATensor* input);
@@ -367,7 +370,6 @@ XLA_API OpaqueXLATensor* XLATensor_sqrt(OpaqueXLATensor* a);
XLA_API OpaqueXLATensor* XLATensor_squeeze(OpaqueXLATensor* a, int64_t dim);
XLA_API OpaqueXLATensor*
XLATensor_stack(OpaqueXLATensorArrayRef tensors, int64_t dim);
XLA_API OpaqueString* XLATensor_ir_text(OpaqueXLATensor* a);
XLA_API OpaqueXLATensor* XLATensor_sub(OpaqueXLATensor* a, OpaqueXLATensor* b);
XLA_API OpaqueXLATensor* XLATensor_sum(OpaqueXLATensor* a, Int64ArrayRef dims,
bool keep_reduced_dimensions,
65 changes: 63 additions & 2 deletions Sources/TensorFlow/Core/Tensor.swift
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import CTensorFlow
import _Differentiation

infix operator .==: ComparisonPrecedence
infix operator .!=: ComparisonPrecedence
@@ -24,7 +24,7 @@ public protocol AnyTensor {
var _tensorFlowDataType: TensorDataType { get }
}

/// A multidimensional array of elements that is a generalization of vectors and matrices to
/// potentially higher dimensions.
///
/// The generic parameter `Scalar` describes the type of scalars in the tensor (such as `Int32`,
@@ -41,6 +41,67 @@ public struct Tensor<Scalar: TensorFlowScalar> {
}
}

public protocol TensorProtocol {
associatedtype Scalar: TensorFlowScalar
init(repeating repeatedValue: Scalar, shape: TensorShape, on device: Device)
var annotations: String { get }
var shape: TensorShape { get }
var summary: String { get }
}

public protocol DifferentiableTensorProtocol:
TensorProtocol & Differentiable & EuclideanDifferentiable
where Scalar: TensorFlowFloatingPoint {
@differentiable(wrt: self)
func annotate(_ annotation: String) -> Self
}

extension Tensor: TensorProtocol & DifferentiableTensorProtocol
where Scalar: TensorFlowFloatingPoint {

public var annotations: String {
#if USING_X10_BACKEND
switch handle.backend {
case .XLA:
let rawAnnotations = XLATensor.annotations(xlaTensor)

// TODO(michellecasbon): Add formatting.

return rawAnnotations

case .TF_EAGER:
return Device.defaultTFEager.annotationsAvailable
}
#else
return "Annotations not available in TF_EAGER."
#endif
}

public var summary: String { annotations }

@differentiable(wrt: self)
public func annotate(_ annotation: String) -> Tensor<Scalar> {
#if USING_X10_BACKEND
switch handle.backend {
case .XLA:
return Tensor<Scalar>(_xla: XLATensor.annotate(xlaTensor, annotation))
case .TF_EAGER:
return self
}
#else
return self
#endif
}

@derivative(of: annotate)
@usableFromInline
func vjpAnnotate(_ annotation: String) -> (
value: Tensor<Scalar>, pullback: (Tensor<Scalar>) -> Tensor<Scalar>
) {
(annotate(annotation), { $0 })
}
}

extension Tensor: AnyTensor {
public var _rawTensorHandle: CTensorHandle { return handle._cTensorHandle }
public var _tensorFlowDataType: TensorDataType { return Scalar.tensorFlowDataType }
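For reference, a minimal usage sketch of the new public API (assuming an XLA device is available; the annotation string and shape here are illustrative):

import TensorFlow

let device = Device.defaultXLA
let x = Tensor<Float>(repeating: 0.5, shape: [2, 3], on: device)

// `annotate` is semantically an identity op that tags the tensor's IR node.
let y = x.annotate("type=relu")

// On the XLA backend this returns the collected annotations; on TF_EAGER it
// returns an availability message instead.
print(y.annotations)
print(y.summary)  // `summary` is an alias for `annotations`.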
13 changes: 13 additions & 0 deletions Sources/x10/swift_bindings/Device.swift
@@ -72,6 +72,13 @@ public struct Device {
case .XLA: return "XLA"
}
}

var annotationsAvailable: String {
switch self {
case .TF_EAGER: return "Annotations not available in TF_EAGER."
case .XLA: return "Annotations available in XLA."
}
}
}

/// A device kind.
@@ -208,6 +215,12 @@ extension Device: CustomStringConvertible {
}
}

extension Device {
public var annotationsAvailable: String {
"\(backend.annotationsAvailable)"
}
}

extension CDevice {
var device: Device {
return Device(kind: hw_type.kind, ordinal: Int(ordinal), backend: .XLA)
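A small sketch of the corresponding `Device` helper (the strings come from the switch above):

let eager = Device.defaultTFEager
print(eager.annotationsAvailable)
// Prints "Annotations not available in TF_EAGER."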
29 changes: 21 additions & 8 deletions Sources/x10/swift_bindings/XLATensor.swift
@@ -242,6 +242,17 @@ extension XLATensor {
}
}

static func annotate(_ a: XLATensor, _ annotation: String) -> XLATensor {
return XLATensor(_handle: XLATensor_annotate(a.handle, annotation))
}

static func annotations(_ a: XLATensor) -> String {
// TODO(michellecasbon): Format with header.
let str = XLATensor_get_annotations(a.handle)
defer { DeleteString(str) }
return String(cString: GetStringCStr(str))
}

static func any(_ input: XLATensor, _ reductionIndices: [Int64], _ keepDims: Bool) -> XLATensor {
defer { _fixLifetime(input) }
return reductionIndices.withArrayRef { reductionIndices in
@@ -407,7 +418,9 @@ extension XLATensor {
return XLATensor(_handle: XLATensor_div(a.handle, b.handle))
}

static func dynamic_slice(_ base: XLATensor, _ start_indices: [XLATensor], _ slice_shape: [Int64])
  -> XLATensor
{
start_indices.withArrayRef { start_indices in
slice_shape.withArrayRef { slice_shape in
return XLATensor(_handle: XLATensor_dynamic_slice(base.handle, start_indices, slice_shape))
@@ -491,6 +504,12 @@ extension XLATensor {
}
}

static func irText(_ a: XLATensor) -> String {
let str = XLATensor_ir_text(a.handle)
defer { DeleteString(str) }
return String(cString: GetStringCStr(str))
}

static func isFinite(_ input: XLATensor) -> XLATensor {
defer { _fixLifetime(input) }
return XLATensor(_handle: XLATensor_is_finite(input.handle))
@@ -761,7 +780,7 @@ extension XLATensor {
}

static func replica_id(_ device: Device) -> XLATensor {
return XLATensor(_handle: XLATensor_replica_id(device.cdevice))
}

static func resize_value(_ value: XLATensor, _ dims: [Int64]) -> XLATensor {
@@ -841,12 +860,6 @@ extension XLATensor {
}
}

static func irText(_ a: XLATensor) -> String {
let str = XLATensor_ir_text(a.handle)
defer { DeleteString(str) }
return String(cString: GetStringCStr(str))
}

static func sub(_ a: XLATensor, _ b: XLATensor) -> XLATensor {
defer { _fixLifetime(a) }
defer { _fixLifetime(b) }
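Because the VJP defined in Tensor.swift passes the cotangent through unchanged, annotations are transparent to differentiation. A hedged sketch (shape and values illustrative):

import TensorFlow

let x = Tensor<Float>(repeating: 2, shape: [3], on: Device.defaultXLA)

// `annotate` contributes nothing to the derivative, so this gradient matches
// that of `{ $0.sum() }`: a tensor of ones.
let g = gradient(at: x) { t in t.annotate("checkpoint").sum() }
print(g)  // [1.0, 1.0, 1.0]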
1 change: 1 addition & 0 deletions Sources/x10/xla_tensor/aten_compat.h
@@ -219,6 +219,7 @@
_(aten, all) \
_(aten, allclose) \
_(aten, alpha_dropout) \
_(aten, annotate) \
_(aten, any) \
_(aten, arange) \
_(aten, argmax) \
32 changes: 32 additions & 0 deletions Sources/x10/xla_tensor/ir_dump_util.cpp
@@ -204,6 +204,22 @@ struct ChangeLogNode {

thread_local std::map<xla::hash_t, std::vector<ChangeLogNode>> g_change_logs;

std::string GenerateTextAnnotation(const Node* node) {
// TODO(michellecasbon): Use json.
std::stringstream ss;
ss << " shape=[";
size_t i = 0;
for (auto& dimension : node->shape().dimensions()) {
if ((i++) != 0) ss << ", ";
ss << dimension;
}
ss << "] ";
for (auto& tag : GetNodeTags(node)) {
ss << tag.value;
}
return ss.str();
}

} // namespace

std::string DumpUtil::ToDot(absl::Span<const Node* const> nodes) {
@@ -323,5 +339,21 @@ std::string DumpUtil::GetGraphChangeLog(absl::Span<const Node* const> roots) {
return ss.str();
}

std::string DumpUtil::GetAnnotations(absl::Span<const Node* const> nodes) {
auto post_order = Util::ComputePostOrder(nodes);

NodeIdMap id_map = GenerateIdMap(post_order);
std::stringstream ss;
ss << "{";
for (auto node : post_order) {
// Only process annotations
if (node->op().ToString() != "x10::annotate") continue;

ss << "\n" << GenerateTextAnnotation(node);
}
ss << "\n" << "}";
return ss.str();
}

} // namespace ir
} // namespace swift_xla
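Given GenerateTextAnnotation and GetAnnotations above, the string surfaced to Swift as `Tensor.annotations` brackets one line per annotate node, along these lines (tag contents illustrative):

{
  shape=[2, 3] type=relu
  shape=[2, 3] type=dense
}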
2 changes: 2 additions & 0 deletions Sources/x10/xla_tensor/ir_dump_util.h
@@ -41,6 +41,8 @@ class DumpUtil {
const Device& device);

static std::string GetGraphChangeLog(absl::Span<const Node* const> roots);

static std::string GetAnnotations(absl::Span<const Node* const> nodes);
};

} // namespace ir
48 changes: 48 additions & 0 deletions Sources/x10/xla_tensor/ops/annotate.cpp
@@ -0,0 +1,48 @@
// Copyright 2020 TensorFlow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensorflow/compiler/tf2xla/xla_tensor/ops/annotate.h"

#include "tensorflow/compiler/xla/xla_client/util.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/elementwise.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"

namespace swift_xla {
namespace ir {
namespace ops {

Annotate::Annotate(const Value& input, std::string annotation)
: Node(ir::OpKind(at::aten::annotate), {input}, input.shape(),
/*num_outputs=*/1, xla::util::MHash()),
annotation_(annotation) {}

NodePtr Annotate::Clone(OpList operands) const {
return MakeNode<Annotate>(operands.at(0), annotation_);
}

XlaOpVector Annotate::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
return ReturnOp(input, loctx);
}

std::string Annotate::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", annotation=" << annotation_;
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace swift_xla

47 changes: 47 additions & 0 deletions Sources/x10/xla_tensor/ops/annotate.h
@@ -0,0 +1,47 @@
/*
* Copyright 2020 TensorFlow Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

// #include <vector>

#include "tensorflow/compiler/tf2xla/xla_tensor/ir.h"

namespace swift_xla {
namespace ir {
namespace ops {

// IR node for collecting layer statistics.
class Annotate : public Node {
public:
Annotate(const Value& input, std::string annotation);

NodePtr Clone(OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

const std::string& annotation() const { return annotation_; }

private:
std::string annotation_;
};

} // namespace ops
} // namespace ir
} // namespace swift_xla

2 changes: 2 additions & 0 deletions Sources/x10/xla_tensor/tensor.h
@@ -298,6 +298,8 @@ class XLATensor {
std::vector<xla::int64> dimensions,
bool keep_reduced_dimensions);

static XLATensor annotate(const XLATensor& input, std::string annotation);

static XLATensor any(const XLATensor& input,
std::vector<xla::int64> dimensions,
bool keep_reduced_dimensions);