Skip to content

Break TensorGroup into InputTensorGroup and OutputTensorGroup. #20188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 1, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ PROTOCOL(Numeric)
PROTOCOL(FloatingPoint)
PROTOCOL(ParameterGroup)
PROTOCOL(Parameterized)
PROTOCOL(TensorGroup)
PROTOCOL(InputTensorGroup)
PROTOCOL(OutputTensorGroup)
PROTOCOL(TensorProtocol)
PROTOCOL(TensorSendableReceivable)
PROTOCOL(VectorNumeric)
Expand Down
3 changes: 2 additions & 1 deletion lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,8 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
case KnownProtocolKind::AccelerableByTensorFlow:
case KnownProtocolKind::ParameterGroup:
case KnownProtocolKind::Parameterized:
case KnownProtocolKind::TensorGroup:
case KnownProtocolKind::InputTensorGroup:
case KnownProtocolKind::OutputTensorGroup:
case KnownProtocolKind::TensorSendableReceivable:
case KnownProtocolKind::TensorProtocol:
M = getLoadedModule(Id_TensorFlow);
Expand Down
3 changes: 2 additions & 1 deletion lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4055,7 +4055,8 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::Numeric:
case KnownProtocolKind::ParameterGroup:
case KnownProtocolKind::Parameterized:
case KnownProtocolKind::TensorGroup:
case KnownProtocolKind::InputTensorGroup:
case KnownProtocolKind::OutputTensorGroup:
case KnownProtocolKind::TensorProtocol:
case KnownProtocolKind::TensorSendableReceivable:
case KnownProtocolKind::VectorNumeric:
Expand Down
38 changes: 17 additions & 21 deletions lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1979,9 +1979,12 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
// TODO: As an optimization, do this lookup once per CurSILFn
auto tfModule = astCtx.getLoadedModule(astCtx.Id_TensorFlow);
assert(tfModule && "could not find TensorFlow module");
auto tensorGroupProto =
astCtx.getProtocol(KnownProtocolKind::TensorGroup);
assert(tensorGroupProto && "could not find TensorGroup protocol");
auto inputTensorGroupProto =
astCtx.getProtocol(KnownProtocolKind::InputTensorGroup);
auto outputTensorGroupProto =
astCtx.getProtocol(KnownProtocolKind::OutputTensorGroup);
assert(inputTensorGroupProto && "could not find InputTensorGroup protocol");
assert(outputTensorGroupProto && "could not find OutputTensorGroup protocol");

if (!llvm::TFDynamicCompilation) {
// If we are not in dynamic compilation mode, then deabstraction may not
Expand Down Expand Up @@ -2051,25 +2054,18 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
// returns an Int32 value for the number of inputs that it has added. There
// are a few different cases that can be unpacked:
// - if `opInput` is a TensorFlow value, then we just add its handle;
// - if `opInput` is an archetype conforming to TensorGroup, then we
// - if `opInput` is an archetype conforming to InputTensorGroup, then we
// ask the conformance for the handles and add those;
// This function crashes if it receives an unhandled case. Earlier
// typechecking should ensure that inputs match the cases that this function
// handles.
//
// TODO: We should also handle the following cases:
// - if `opInput` is an array of TensorFlow values, then we add all the
// elements' handles;
// - if `opInput` is an array of archetypes conforming to TensorGroup,
// then we ask the conformance for all the elements' handles, flatten the
// results together, and add those;
auto unpackAndAddInput = [&](SILValue opInput) -> llvm::Value* {
LLVM_DEBUG(llvm::dbgs()
<< " Adding input of type " << opInput->getType() << ".\n");

// If this is a known TensorFlow value, add it directly.
// TODO: We could also handle concrete structs of known TensorFlow values
// here, to avoid falling through to the slower TensorGroup case.
// here, to avoid falling through to the slower InputTensorGroup case.
if (tf::isTensorFlowValue(opInput->getType())) {
auto *tensorHandleValue = getLoweredSingletonExplosion(opInput);
auto *opAddInputFromTensorHandleFn =
Expand All @@ -2082,13 +2078,13 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
return llvm::ConstantInt::get(IGM.Int32Ty, 1);
}

// Otherwise, this must conform to TensorGroup so we can add it using
// Otherwise, this must conform to InputTensorGroup so we can add it using
// TFC_OpAddInputFromTensorGroup.

auto canType = opInput->getType().getASTType()->getCanonicalType();
auto conformance = tfModule->lookupConformance(canType,
tensorGroupProto);
assert(conformance && "input type does not conform to TensorGroup");
auto conformance =
tfModule->lookupConformance(canType, inputTensorGroupProto);
assert(conformance && "input type does not conform to InputTensorGroup");
auto *typeMetadata = emitTypeMetadataRef(canType);
auto *wtable = emitWitnessTableRef(*this, canType, *conformance);

Expand Down Expand Up @@ -2228,8 +2224,8 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
outParameterCanType =
silValue->getType().getASTType()->getCanonicalType();
auto conformance = tfModule->lookupConformance(outParameterCanType,
tensorGroupProto);
assert(conformance && "out type does not conform to TensorGroup");
outputTensorGroupProto);
assert(conformance && "out type does not conform to OutputTensorGroup");
outParameterTypeMetadata = emitTypeMetadataRef(outParameterCanType);
outParameterTensorGroupWitnessTable =
emitWitnessTableRef(*this, outParameterCanType, *conformance);
Expand Down Expand Up @@ -2888,13 +2884,13 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
// TensorGroup for the number of outputs that it needs.

assert(hasOpaqueTensorGroupResults &&
"found an unexpected opaque TensorGroup result");
"found an unexpected opaque OutputTensorGroup result");

// Emit the type metadata and witness table.
auto canType = silResult->getType().getASTType()->getCanonicalType();
auto conformance = tfModule->lookupConformance(canType,
tensorGroupProto);
assert(conformance && "out type does not conform to TensorGroup");
outputTensorGroupProto);
assert(conformance && "out type does not conform to OutputTensorGroup");
auto *typeMetadata = emitTypeMetadataRef(canType);
directResultTypeMetadatas.push_back(typeMetadata);
auto *witnessTable =
Expand Down
18 changes: 9 additions & 9 deletions lib/SILOptimizer/Mandatory/TFDeabstraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2198,9 +2198,9 @@ static bool unpackTensorAggregates(

auto tfModule = ctx.getLoadedModule(ctx.Id_TensorFlow);
assert(tfModule && "could not find TensorFlow module");
auto tensorGroupProto =
ctx.getProtocol(KnownProtocolKind::TensorGroup);
assert(tensorGroupProto && "could not find TensorGroup protocol");
auto inputTensorGroupProto =
ctx.getProtocol(KnownProtocolKind::InputTensorGroup);
assert(inputTensorGroupProto && "could not find InputTensorGroup protocol");

std::function<bool(SILValue)> recurse;
recurse = [&](SILValue aggregate) -> bool {
Expand All @@ -2209,6 +2209,12 @@ static bool unpackTensorAggregates(
inputList.push_back(aggregate);
return false;
}
if (acceptTensorGroupConformingLeaves &&
tfModule->lookupConformance(aggregateTy.getASTType(),
inputTensorGroupProto)) {
inputList.push_back(aggregate);
return false;
}
if (auto tupleTy = aggregateTy.getAs<TupleType>()) {
for (auto i : range(tupleTy->getNumElements())) {
auto eltIdx = std::pair<SILValue, unsigned>(aggregate, i);
Expand Down Expand Up @@ -2237,12 +2243,6 @@ static bool unpackTensorAggregates(
}
return false;
}
if (acceptTensorGroupConformingLeaves &&
tfModule->lookupConformance(aggregateTy.getASTType(),
tensorGroupProto)) {
inputList.push_back(aggregate);
return false;
}
return true;
};

Expand Down
14 changes: 7 additions & 7 deletions stdlib/public/TensorFlow/CompilerRuntime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1177,12 +1177,12 @@ func _TFCOpAddInputFromTensorHandle(_ op: CTFEOp,
/// Adds `t` as an input or inputs to `op`. Returns the number of inputs added.
@usableFromInline
@_silgen_name("_swift_tfc_OpAddInputFromTensorGroup")
func _TFCOpAddInputFromTensorGroup<T : TensorGroup>(
func _TFCOpAddInputFromTensorGroup<T : InputTensorGroup>(
_ op: CTFEOp, _ t: T, _ status: CTFStatus
) -> Int32 {
let count = Int(T._tensorHandleCount)
let count = t._inputTensorHandleCount
let buffer =
UnsafeMutableBufferPointer<CTensorHandle>.allocate(capacity: count)
UnsafeMutableBufferPointer<CTensorHandle>.allocate(capacity: Int(count))
defer { buffer.deallocate() }
t._unpackTensorHandles(into: buffer.baseAddress)
for handle in buffer {
Expand All @@ -1191,14 +1191,14 @@ func _TFCOpAddInputFromTensorGroup<T : TensorGroup>(
return 0
}
}
return T._tensorHandleCount
return count
}

/// Initializes a TensorGroup value, taking ownership of all the tensor
/// handles in `tensorHandles`.
@usableFromInline
@_silgen_name("_swift_tfc_InitTensorGroup")
func _TFCInitTensorGroup<T : TensorGroup>(
func _TFCInitTensorGroup<T : OutputTensorGroup>(
_ tensorHandles: UnsafeMutablePointer<CTensorHandle>
) -> T {
return T(_owning: tensorHandles)
Expand All @@ -1223,10 +1223,10 @@ func _TFCDeallocateCHandleBuffer(

/// Returns the number of CTensorHandles in a TensorGroup of type T.
@_silgen_name("_swift_tfc_GetTensorGroupCHandleCount")
public func _TFCGetTensorGroupCHandleCount<T : TensorGroup>(
public func _TFCGetTensorGroupCHandleCount<T : OutputTensorGroup>(
_ type: T.Type
) -> Int32 {
return T._tensorHandleCount
return T._outputTensorHandleCount
}

@inlinable
Expand Down
92 changes: 72 additions & 20 deletions stdlib/public/TensorFlow/TensorGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,51 +16,72 @@

import CTensorFlow

/// A protocol for types that can be used as tensor operation inputs and
/// outputs. When a TensorGroup is used as an input, it gets passed to the
/// tensor operation as an input list whose elements are the tensor fields of
/// the type. When a TensorGroup is used as an output, it gets initialized
/// with its tensor fields set to the tensor operation's tensor outputs.
/// A protocol for types that can be used as tensor operation inputs. When a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"tensor operation inputs" in an implementation-specific comment. It does not have a meaning to the user.

/// TensorGroup is used as an input, it gets passed to the tensor operation as
/// an input list whose elements are the tensor fields of the type.
///
/// TODO: Add a derived conformance to TensorGroup so that users don't have
/// to write the conformance themselves.
public protocol TensorGroup {
/// The types of the tensor stored properties in this type.
static var _typeList: [TensorDataType] { get }

/// This protocol is divided from OutputTensorGroup in order for the number of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not valid Swift API doc comment because it's implementation-specific, not something the user is expected to understand. If you wanted it to be seen the library implementer, please use double slashes.

/// tensors to be determined at runtime. For example, Array<Tensor<Float>> may
/// have an unknown number of elements at compile time.
public protocol InputTensorGroup {
/// Writes the tensor handles to `address`, which must be allocated
/// with enough capacity to hold `_tensorHandleCount` handles. The tensor
/// handles written to `address` are borrowed: this container still
/// owns them.
func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?)

var _inputTensorHandleCount : Int32 { get }
}

/// A protocol for types that can be used as tensor operation outputs. When a
/// TensorGroup is used as an output, it gets initialized with its tensor fields
/// set to the tensor operation's tensor outputs.
/// The number of tensors must be known at compile time.
public protocol OutputTensorGroup {
/// The types of the tensor stored properties in this type.
static var _outputTypeList: [TensorDataType] { get }

/// Initializes a value of this type, taking ownership of the
/// `_tensorHandleCount` tensors that are at `tensorHandles`.
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?)
}

public extension TensorGroup {
public extension OutputTensorGroup {
/// The number of tensor fields in this type.
static var _tensorHandleCount: Int32 {
return Int32(_typeList.count)
static var _outputTensorHandleCount: Int32 {
return Int32(_outputTypeList.count)
}

/// An array of `nil`s with size equal to `_tensorHandleCount`. The `nil`
/// represents unknown shape.
static var _unknownShapeList: [TensorShape?] {
return Array(repeating: nil, count: Int(_tensorHandleCount))
return Array(repeating: nil, count: Int(_outputTensorHandleCount))
}
}

/// A protocol for types that can be used as tensor operation inputs and
/// outputs. When a TensorGroup is used as an input, it gets passed to the
/// tensor operation as an input list whose elements are the tensor fields of
/// the type. When a TensorGroup is used as an output, it gets initialized
/// with its tensor fields set to the tensor operation's tensor outputs.
///
/// TODO: Add a derived conformance to TensorGroup so that users don't have
/// to write the conformance themselves.
public protocol TensorGroup : InputTensorGroup & OutputTensorGroup {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not idiomatic. Please change this to a typealias instead.

public typealias TensorGroup = InputTensorGroup & OutputTensorGroup


//===----------------------------------------------------------------------===//
// Conform standard TensorFlow types to TensorGroup
//===----------------------------------------------------------------------===//

extension TensorHandle : TensorGroup {
public static var _typeList: [TensorDataType] {
public static var _outputTypeList: [TensorDataType] {
return [Scalar.tensorFlowDataType]
}

public var _inputTensorHandleCount : Int32 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the space before :.

get { return Int32(TensorHandle._outputTypeList.count) }
}

public func _unpackTensorHandles(
into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: _cTensorHandle)
Expand All @@ -72,10 +93,14 @@ extension TensorHandle : TensorGroup {
}

extension ResourceHandle : TensorGroup {
public static var _typeList: [TensorDataType] {
public static var _outputTypeList: [TensorDataType] {
return [TensorDataType(TF_RESOURCE)]
}

public var _inputTensorHandleCount : Int32 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the space before :.

get { return Int32(ResourceHandle._outputTypeList.count) }
}

public func _unpackTensorHandles(
into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: _cTensorHandle)
Expand All @@ -87,10 +112,14 @@ extension ResourceHandle : TensorGroup {
}

extension VariantHandle : TensorGroup {
public static var _typeList: [TensorDataType] {
public static var _outputTypeList: [TensorDataType] {
return [TensorDataType(TF_VARIANT)]
}

public var _inputTensorHandleCount : Int32 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the space before :.

get { return Int32(VariantHandle._outputTypeList.count) }
}

public func _unpackTensorHandles(
into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: _cTensorHandle)
Expand All @@ -102,10 +131,14 @@ extension VariantHandle : TensorGroup {
}

extension Tensor : TensorGroup {
public static var _typeList: [TensorDataType] {
public static var _outputTypeList: [TensorDataType] {
return [Scalar.tensorFlowDataType]
}

public var _inputTensorHandleCount : Int32 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the space before :.

get { return Int32(Tensor._outputTypeList.count) }
}

public func _unpackTensorHandles(
into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: handle._cTensorHandle)
Expand All @@ -117,10 +150,14 @@ extension Tensor : TensorGroup {
}

extension TensorElementLiteral : TensorGroup {
public static var _typeList: [TensorDataType] {
public static var _outputTypeList: [TensorDataType] {
return [Scalar.tensorFlowDataType]
}

public var _inputTensorHandleCount : Int32 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the space before :.

get { return Int32(TensorElementLiteral._outputTypeList.count) }
}

public func _unpackTensorHandles(
into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: handle._cTensorHandle)
Expand All @@ -130,3 +167,18 @@ extension TensorElementLiteral : TensorGroup {
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
}
}

extension Array : InputTensorGroup where Element : InputTensorGroup {
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
var ptr = address
for elem in self {
elem._unpackTensorHandles(into: ptr)
ptr = ptr!.advanced(by: Int(elem._inputTensorHandleCount))
}
}
public var _inputTensorHandleCount : Int32 { get {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please reformat this as the following:

public var _inputTensorHandleCount : Int32 {
  get {
    ...
  }
}

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be even better to do it without the get, like this?

public var _inputTensorHandleCount: Int32 {
  return ...
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it definitely would. I didn't notice that there's no setter and no special attributes on the getter.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove space before :.

var count: Int32 = 0
for elem in self { count += elem._inputTensorHandleCount }
return count
} }
}
Loading