Skip to content

Commit c605516

Browse files
authored
Merge pull request #20188 from pschuh/3
Break TensorGroup into InputTensorGroup and OutputTensorGroup.
2 parents bd869ab + d980da6 commit c605516

File tree

8 files changed

+140
-67
lines changed

8 files changed

+140
-67
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ PROTOCOL(Numeric)
7373
PROTOCOL(FloatingPoint)
7474
PROTOCOL(ParameterGroup)
7575
PROTOCOL(Parameterized)
76-
PROTOCOL(TensorGroup)
76+
PROTOCOL(InputTensorGroup)
77+
PROTOCOL(OutputTensorGroup)
7778
PROTOCOL(TensorProtocol)
7879
PROTOCOL(TensorSendableReceivable)
7980
PROTOCOL(VectorNumeric)

lib/AST/ASTContext.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,8 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
972972
case KnownProtocolKind::AccelerableByTensorFlow:
973973
case KnownProtocolKind::ParameterGroup:
974974
case KnownProtocolKind::Parameterized:
975-
case KnownProtocolKind::TensorGroup:
975+
case KnownProtocolKind::InputTensorGroup:
976+
case KnownProtocolKind::OutputTensorGroup:
976977
case KnownProtocolKind::TensorSendableReceivable:
977978
case KnownProtocolKind::TensorProtocol:
978979
M = getLoadedModule(Id_TensorFlow);

lib/IRGen/GenMeta.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4055,7 +4055,8 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
40554055
case KnownProtocolKind::Numeric:
40564056
case KnownProtocolKind::ParameterGroup:
40574057
case KnownProtocolKind::Parameterized:
4058-
case KnownProtocolKind::TensorGroup:
4058+
case KnownProtocolKind::InputTensorGroup:
4059+
case KnownProtocolKind::OutputTensorGroup:
40594060
case KnownProtocolKind::TensorProtocol:
40604061
case KnownProtocolKind::TensorSendableReceivable:
40614062
case KnownProtocolKind::VectorNumeric:

lib/IRGen/IRGenSIL.cpp

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1979,9 +1979,12 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
19791979
// TODO: As an optimization, do this lookup once per CurSILFn
19801980
auto tfModule = astCtx.getLoadedModule(astCtx.Id_TensorFlow);
19811981
assert(tfModule && "could not find TensorFlow module");
1982-
auto tensorGroupProto =
1983-
astCtx.getProtocol(KnownProtocolKind::TensorGroup);
1984-
assert(tensorGroupProto && "could not find TensorGroup protocol");
1982+
auto inputTensorGroupProto =
1983+
astCtx.getProtocol(KnownProtocolKind::InputTensorGroup);
1984+
auto outputTensorGroupProto =
1985+
astCtx.getProtocol(KnownProtocolKind::OutputTensorGroup);
1986+
assert(inputTensorGroupProto && "could not find InputTensorGroup protocol");
1987+
assert(outputTensorGroupProto && "could not find OutputTensorGroup protocol");
19851988

19861989
if (!llvm::TFDynamicCompilation) {
19871990
// If we are not in dynamic compilation mode, then deabstraction may not
@@ -2051,25 +2054,18 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
20512054
// returns an Int32 value for the number of inputs that it has added. There
20522055
// are a few different cases that can be unpacked:
20532056
// - if `opInput` is a TensorFlow value, then we just add its handle;
2054-
// - if `opInput` is an archetype conforming to TensorGroup, then we
2057+
// - if `opInput` is an archetype conforming to InputTensorGroup, then we
20552058
// ask the conformance for the handles and add those;
20562059
// This function crashes if it receives an unhandled case. Earlier
20572060
// typechecking should ensure that inputs match the cases that this function
20582061
// handles.
2059-
//
2060-
// TODO: We should also handle the following cases:
2061-
// - if `opInput` is an array of TensorFlow values, then we add all the
2062-
// elements' handles;
2063-
// - if `opInput` is an array of archetypes conforming to TensorGroup,
2064-
// then we ask the conformance for all the elements' handles, flatten the
2065-
// results together, and add those;
20662062
auto unpackAndAddInput = [&](SILValue opInput) -> llvm::Value* {
20672063
LLVM_DEBUG(llvm::dbgs()
20682064
<< " Adding input of type " << opInput->getType() << ".\n");
20692065

20702066
// If this is a known TensorFlow value, add it directly.
20712067
// TODO: We could also handle concrete structs of known TensorFlow values
2072-
// here, to avoid falling through to the slower TensorGroup case.
2068+
// here, to avoid falling through to the slower InputTensorGroup case.
20732069
if (tf::isTensorFlowValue(opInput->getType())) {
20742070
auto *tensorHandleValue = getLoweredSingletonExplosion(opInput);
20752071
auto *opAddInputFromTensorHandleFn =
@@ -2082,13 +2078,13 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
20822078
return llvm::ConstantInt::get(IGM.Int32Ty, 1);
20832079
}
20842080

2085-
// Otherwise, this must conform to TensorGroup so we can add it using
2081+
// Otherwise, this must conform to InputTensorGroup so we can add it using
20862082
// TFC_OpAddInputFromTensorGroup.
20872083

20882084
auto canType = opInput->getType().getASTType()->getCanonicalType();
2089-
auto conformance = tfModule->lookupConformance(canType,
2090-
tensorGroupProto);
2091-
assert(conformance && "input type does not conform to TensorGroup");
2085+
auto conformance =
2086+
tfModule->lookupConformance(canType, inputTensorGroupProto);
2087+
assert(conformance && "input type does not conform to InputTensorGroup");
20922088
auto *typeMetadata = emitTypeMetadataRef(canType);
20932089
auto *wtable = emitWitnessTableRef(*this, canType, *conformance);
20942090

@@ -2228,8 +2224,8 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
22282224
outParameterCanType =
22292225
silValue->getType().getASTType()->getCanonicalType();
22302226
auto conformance = tfModule->lookupConformance(outParameterCanType,
2231-
tensorGroupProto);
2232-
assert(conformance && "out type does not conform to TensorGroup");
2227+
outputTensorGroupProto);
2228+
assert(conformance && "out type does not conform to OutputTensorGroup");
22332229
outParameterTypeMetadata = emitTypeMetadataRef(outParameterCanType);
22342230
outParameterTensorGroupWitnessTable =
22352231
emitWitnessTableRef(*this, outParameterCanType, *conformance);
@@ -2888,13 +2884,13 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
28882884
// TensorGroup for the number of outputs that it needs.
28892885

28902886
assert(hasOpaqueTensorGroupResults &&
2891-
"found an unexpected opaque TensorGroup result");
2887+
"found an unexpected opaque OutputTensorGroup result");
28922888

28932889
// Emit the type metadata and witness table.
28942890
auto canType = silResult->getType().getASTType()->getCanonicalType();
28952891
auto conformance = tfModule->lookupConformance(canType,
2896-
tensorGroupProto);
2897-
assert(conformance && "out type does not conform to TensorGroup");
2892+
outputTensorGroupProto);
2893+
assert(conformance && "out type does not conform to OutputTensorGroup");
28982894
auto *typeMetadata = emitTypeMetadataRef(canType);
28992895
directResultTypeMetadatas.push_back(typeMetadata);
29002896
auto *witnessTable =

lib/SILOptimizer/Mandatory/TFDeabstraction.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2198,9 +2198,9 @@ static bool unpackTensorAggregates(
21982198

21992199
auto tfModule = ctx.getLoadedModule(ctx.Id_TensorFlow);
22002200
assert(tfModule && "could not find TensorFlow module");
2201-
auto tensorGroupProto =
2202-
ctx.getProtocol(KnownProtocolKind::TensorGroup);
2203-
assert(tensorGroupProto && "could not find TensorGroup protocol");
2201+
auto inputTensorGroupProto =
2202+
ctx.getProtocol(KnownProtocolKind::InputTensorGroup);
2203+
assert(inputTensorGroupProto && "could not find TensorGroup protocol");
22042204

22052205
std::function<bool(SILValue)> recurse;
22062206
recurse = [&](SILValue aggregate) -> bool {
@@ -2209,6 +2209,12 @@ static bool unpackTensorAggregates(
22092209
inputList.push_back(aggregate);
22102210
return false;
22112211
}
2212+
if (acceptTensorGroupConformingLeaves &&
2213+
tfModule->lookupConformance(aggregateTy.getASTType(),
2214+
inputTensorGroupProto)) {
2215+
inputList.push_back(aggregate);
2216+
return false;
2217+
}
22122218
if (auto tupleTy = aggregateTy.getAs<TupleType>()) {
22132219
for (auto i : range(tupleTy->getNumElements())) {
22142220
auto eltIdx = std::pair<SILValue, unsigned>(aggregate, i);
@@ -2237,12 +2243,6 @@ static bool unpackTensorAggregates(
22372243
}
22382244
return false;
22392245
}
2240-
if (acceptTensorGroupConformingLeaves &&
2241-
tfModule->lookupConformance(aggregateTy.getASTType(),
2242-
tensorGroupProto)) {
2243-
inputList.push_back(aggregate);
2244-
return false;
2245-
}
22462246
return true;
22472247
};
22482248

stdlib/public/TensorFlow/CompilerRuntime.swift

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,12 +1177,12 @@ func _TFCOpAddInputFromTensorHandle(_ op: CTFEOp,
11771177
/// Adds `t` as an input or inputs to `op`. Returns the number of inputs added.
11781178
@usableFromInline
11791179
@_silgen_name("_swift_tfc_OpAddInputFromTensorGroup")
1180-
func _TFCOpAddInputFromTensorGroup<T : TensorGroup>(
1180+
func _TFCOpAddInputFromTensorGroup<T : InputTensorGroup>(
11811181
_ op: CTFEOp, _ t: T, _ status: CTFStatus
11821182
) -> Int32 {
1183-
let count = Int(T._tensorHandleCount)
1183+
let count = t._inputTensorHandleCount
11841184
let buffer =
1185-
UnsafeMutableBufferPointer<CTensorHandle>.allocate(capacity: count)
1185+
UnsafeMutableBufferPointer<CTensorHandle>.allocate(capacity: Int(count))
11861186
defer { buffer.deallocate() }
11871187
t._unpackTensorHandles(into: buffer.baseAddress)
11881188
for handle in buffer {
@@ -1191,14 +1191,14 @@ func _TFCOpAddInputFromTensorGroup<T : TensorGroup>(
11911191
return 0
11921192
}
11931193
}
1194-
return T._tensorHandleCount
1194+
return count
11951195
}
11961196

11971197
/// Initializes an OutputTensorGroup value, taking ownership of all the tensor
11981198
/// handles in `tensorHandles`.
11991199
@usableFromInline
12001200
@_silgen_name("_swift_tfc_InitTensorGroup")
1201-
func _TFCInitTensorGroup<T : TensorGroup>(
1201+
func _TFCInitTensorGroup<T : OutputTensorGroup>(
12021202
_ tensorHandles: UnsafeMutablePointer<CTensorHandle>
12031203
) -> T {
12041204
return T(_owning: tensorHandles)
@@ -1223,10 +1223,10 @@ func _TFCDeallocateCHandleBuffer(
12231223

12241224
/// Returns the number of CTensorHandles in an OutputTensorGroup of type T.
12251225
@_silgen_name("_swift_tfc_GetTensorGroupCHandleCount")
1226-
public func _TFCGetTensorGroupCHandleCount<T : TensorGroup>(
1226+
public func _TFCGetTensorGroupCHandleCount<T : OutputTensorGroup>(
12271227
_ type: T.Type
12281228
) -> Int32 {
1229-
return T._tensorHandleCount
1229+
return T._outputTensorHandleCount
12301230
}
12311231

12321232
@inlinable

stdlib/public/TensorFlow/TensorGroup.swift

Lines changed: 72 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,51 +16,72 @@
1616

1717
import CTensorFlow
1818

19-
/// A protocol for types that can be used as tensor operation inputs and
20-
/// outputs. When a TensorGroup is used as an input, it gets passed to the
21-
/// tensor operation as an input list whose elements are the tensor fields of
22-
/// the type. When a TensorGroup is used as an output, it gets initialized
23-
/// with its tensor fields set to the tensor operation's tensor outputs.
19+
/// A protocol for types that can be used as tensor operation inputs. When a
20+
/// TensorGroup is used as an input, it gets passed to the tensor operation as
21+
/// an input list whose elements are the tensor fields of the type.
2422
///
25-
/// TODO: Add a derived conformance to TensorGroup so that users don't have
26-
/// to write the conformance themselves.
27-
public protocol TensorGroup {
28-
/// The types of the tensor stored properties in this type.
29-
static var _typeList: [TensorDataType] { get }
30-
23+
/// This protocol is separate from OutputTensorGroup in order for the number of
24+
/// tensors to be determined at runtime. For example, Array<Tensor<Float>> may
25+
/// have an unknown number of elements at compile time.
26+
public protocol InputTensorGroup {
3127
/// Writes the tensor handles to `address`, which must be allocated
3228
/// with enough capacity to hold `_inputTensorHandleCount` handles. The tensor
3329
/// handles written to `address` are borrowed: this container still
3430
/// owns them.
3531
func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?)
3632

33+
var _inputTensorHandleCount : Int32 { get }
34+
}
35+
36+
/// A protocol for types that can be used as tensor operation outputs. When a
37+
/// TensorGroup is used as an output, it gets initialized with its tensor fields
38+
/// set to the tensor operation's tensor outputs.
39+
/// The number of tensors must be known at compile time.
40+
public protocol OutputTensorGroup {
41+
/// The types of the tensor stored properties in this type.
42+
static var _outputTypeList: [TensorDataType] { get }
43+
3744
/// Initializes a value of this type, taking ownership of the
3845
/// `_outputTensorHandleCount` tensors that are at `tensorHandles`.
3946
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?)
4047
}
4148

42-
public extension TensorGroup {
49+
public extension OutputTensorGroup {
4350
/// The number of tensor fields in this type.
44-
static var _tensorHandleCount: Int32 {
45-
return Int32(_typeList.count)
51+
static var _outputTensorHandleCount: Int32 {
52+
return Int32(_outputTypeList.count)
4653
}
4754

4855
/// An array of `nil`s with size equal to `_outputTensorHandleCount`. The `nil`
4956
/// represents unknown shape.
5057
static var _unknownShapeList: [TensorShape?] {
51-
return Array(repeating: nil, count: Int(_tensorHandleCount))
58+
return Array(repeating: nil, count: Int(_outputTensorHandleCount))
5259
}
5360
}
5461

62+
/// A protocol for types that can be used as tensor operation inputs and
63+
/// outputs. When a TensorGroup is used as an input, it gets passed to the
64+
/// tensor operation as an input list whose elements are the tensor fields of
65+
/// the type. When a TensorGroup is used as an output, it gets initialized
66+
/// with its tensor fields set to the tensor operation's tensor outputs.
67+
///
68+
/// TODO: Add a derived conformance to TensorGroup so that users don't have
69+
/// to write the conformance themselves.
70+
public protocol TensorGroup : InputTensorGroup & OutputTensorGroup {}
71+
5572
//===----------------------------------------------------------------------===//
5673
// Conform standard TensorFlow types to TensorGroup
5774
//===----------------------------------------------------------------------===//
5875

5976
extension TensorHandle : TensorGroup {
60-
public static var _typeList: [TensorDataType] {
77+
public static var _outputTypeList: [TensorDataType] {
6178
return [Scalar.tensorFlowDataType]
6279
}
6380

81+
public var _inputTensorHandleCount : Int32 {
82+
get { return Int32(TensorHandle._outputTypeList.count) }
83+
}
84+
6485
public func _unpackTensorHandles(
6586
into address: UnsafeMutablePointer<CTensorHandle>?) {
6687
address!.initialize(to: _cTensorHandle)
@@ -72,10 +93,14 @@ extension TensorHandle : TensorGroup {
7293
}
7394

7495
extension ResourceHandle : TensorGroup {
75-
public static var _typeList: [TensorDataType] {
96+
public static var _outputTypeList: [TensorDataType] {
7697
return [TensorDataType(TF_RESOURCE)]
7798
}
7899

100+
public var _inputTensorHandleCount : Int32 {
101+
get { return Int32(ResourceHandle._outputTypeList.count) }
102+
}
103+
79104
public func _unpackTensorHandles(
80105
into address: UnsafeMutablePointer<CTensorHandle>?) {
81106
address!.initialize(to: _cTensorHandle)
@@ -87,10 +112,14 @@ extension ResourceHandle : TensorGroup {
87112
}
88113

89114
extension VariantHandle : TensorGroup {
90-
public static var _typeList: [TensorDataType] {
115+
public static var _outputTypeList: [TensorDataType] {
91116
return [TensorDataType(TF_VARIANT)]
92117
}
93118

119+
public var _inputTensorHandleCount : Int32 {
120+
get { return Int32(VariantHandle._outputTypeList.count) }
121+
}
122+
94123
public func _unpackTensorHandles(
95124
into address: UnsafeMutablePointer<CTensorHandle>?) {
96125
address!.initialize(to: _cTensorHandle)
@@ -102,10 +131,14 @@ extension VariantHandle : TensorGroup {
102131
}
103132

104133
extension Tensor : TensorGroup {
105-
public static var _typeList: [TensorDataType] {
134+
public static var _outputTypeList: [TensorDataType] {
106135
return [Scalar.tensorFlowDataType]
107136
}
108137

138+
public var _inputTensorHandleCount : Int32 {
139+
get { return Int32(Tensor._outputTypeList.count) }
140+
}
141+
109142
public func _unpackTensorHandles(
110143
into address: UnsafeMutablePointer<CTensorHandle>?) {
111144
address!.initialize(to: handle._cTensorHandle)
@@ -117,10 +150,14 @@ extension Tensor : TensorGroup {
117150
}
118151

119152
extension TensorElementLiteral : TensorGroup {
120-
public static var _typeList: [TensorDataType] {
153+
public static var _outputTypeList: [TensorDataType] {
121154
return [Scalar.tensorFlowDataType]
122155
}
123156

157+
public var _inputTensorHandleCount : Int32 {
158+
get { return Int32(TensorElementLiteral._outputTypeList.count) }
159+
}
160+
124161
public func _unpackTensorHandles(
125162
into address: UnsafeMutablePointer<CTensorHandle>?) {
126163
address!.initialize(to: handle._cTensorHandle)
@@ -130,3 +167,18 @@ extension TensorElementLiteral : TensorGroup {
130167
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
131168
}
132169
}
170+
171+
extension Array : InputTensorGroup where Element : InputTensorGroup {
172+
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
173+
var ptr = address
174+
for elem in self {
175+
elem._unpackTensorHandles(into: ptr)
176+
ptr = ptr!.advanced(by: Int(elem._inputTensorHandleCount))
177+
}
178+
}
179+
public var _inputTensorHandleCount : Int32 { get {
180+
var count: Int32 = 0
181+
for elem in self { count += elem._inputTensorHandleCount }
182+
return count
183+
} }
184+
}

0 commit comments

Comments
 (0)