Skip to content

Commit bde610a

Browse files
committed
Added a '_typeList' property to 'TensorArrayProtocol'.
1 parent 0647ce2 commit bde610a

File tree

4 files changed

+106
-4
lines changed

4 files changed

+106
-4
lines changed

lib/Sema/DerivedConformanceTensorArrayProtocol.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,89 @@ static ValueDecl *deriveTensorArrayProtocol_tensorHandleCount(
363363
return tensorHandleCountDecl;
364364
}
365365

366+
367+
/// Derive the body for the '_typeList' getter.
368+
static void
369+
deriveBodyTensorArrayProtocol_typeList(AbstractFunctionDecl *funcDecl) {
370+
auto *parentDC = funcDecl->getParent();
371+
auto *nominal = funcDecl->getDeclContext()->getSelfNominalTypeDecl();
372+
auto &C = nominal->getASTContext();
373+
374+
auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
375+
auto *typeListReq = getProtocolRequirement(tensorGroupProto, C.Id_typeList);
376+
377+
// Concatenate all member `_typeList` arrays.
378+
Type arrayType = BoundGenericType::get(
379+
C.getArrayDecl(), Type(),
380+
{C.getTensorDataTypeDecl()->getDeclaredInterfaceType()});
381+
auto *arrayTypeExpr = TypeExpr::createImplicit(arrayType, C);
382+
auto plusOpLookup = C.getArrayDecl()->lookupDirect(C.getIdentifier("+"));
383+
assert(plusOpLookup.size() == 1 && "Ambiguous 'Array.+' operator.");
384+
ValueDecl *plusOpDecl = plusOpLookup.front();
385+
auto plusOpDRE = new (C)
386+
DeclRefExpr(plusOpDecl, DeclNameLoc(), /*Implicit*/ true);
387+
auto plusOpExpr = new (C)
388+
DotSyntaxCallExpr(plusOpDRE, SourceLoc(), arrayTypeExpr);
389+
Expr *typeListExpr = ArrayExpr::create(C, SourceLoc(), {}, {}, SourceLoc());
390+
for (auto member : nominal->getStoredProperties()) {
391+
auto memberType =
392+
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
393+
auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
394+
auto *memberTypeListExpr = new (C)
395+
MemberRefExpr(memberTypeExpr, SourceLoc(), typeListReq,
396+
DeclNameLoc(), /*Implicit*/ true);
397+
// Create expression `lhsArg + rhsArg`.
398+
auto *plusOpArgs =
399+
TupleExpr::create(C, SourceLoc(), {typeListExpr, memberTypeListExpr},
400+
{}, {}, SourceLoc(), /*HasTrailingClosure*/ false,
401+
/*Implicit*/ true);
402+
typeListExpr = new (C) BinaryExpr(plusOpExpr, plusOpArgs,
403+
/*Implicit*/ true);
404+
}
405+
406+
// Return the resulting data types array.
407+
auto *returnStmt = new (C) ReturnStmt(SourceLoc(), typeListExpr);
408+
auto *body = BraceStmt::create(C, SourceLoc(), {returnStmt}, SourceLoc(),
409+
/*Implicit*/ true);
410+
funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {body}, SourceLoc(),
411+
/*Implicit*/ true));
412+
}
413+
414+
/// Derive a '_typeList' implementation.
415+
static ValueDecl *deriveTensorArrayProtocol_typeList(
416+
DerivedConformance &derived) {
417+
auto nominal = derived.Nominal;
418+
auto &TC = derived.TC;
419+
ASTContext &C = TC.Context;
420+
421+
auto parentDC = derived.getConformanceContext();
422+
Type dataTypeArrayType = BoundGenericType::get(
423+
C.getArrayDecl(), Type(),
424+
{C.getTensorDataTypeDecl()->getDeclaredInterfaceType()});
425+
auto returnType = parentDC->mapTypeIntoContext(dataTypeArrayType);
426+
427+
// Create `_typeList` property declaration.
428+
VarDecl *typeListDecl;
429+
PatternBindingDecl *patDecl;
430+
std::tie(typeListDecl, patDecl) = derived.declareDerivedProperty(
431+
C.Id_typeList, returnType, returnType, /*isStatic*/ false,
432+
/*isFinal*/ false);
433+
434+
// Add `@inlinable` to the `_typeList` declaration.
435+
if (nominal->getEffectiveAccess() > AccessLevel::Internal)
436+
typeListDecl->getAttrs().add(new (C) InlinableAttr(/*implicit*/ true));
437+
438+
// Create `_typeList` getter.
439+
auto *getterDecl = derived.declareDerivedPropertyGetter(
440+
TC, typeListDecl, returnType);
441+
getterDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_typeList);
442+
typeListDecl->setAccessors(StorageImplInfo::getImmutableComputed(),
443+
SourceLoc(), {getterDecl}, SourceLoc());
444+
derived.addMembersToConformanceContext({getterDecl, typeListDecl, patDecl});
445+
446+
return typeListDecl;
447+
}
448+
366449
// Synthesize body for `init(_owning:count:)`.
367450
static void
368451
deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) {
@@ -584,6 +667,8 @@ ValueDecl *DerivedConformance::deriveTensorArrayProtocol(
584667
return deriveTensorArrayProtocol_unpackTensorHandles(*this);
585668
if (requirement->getBaseName() == TC.Context.Id_tensorHandleCount)
586669
return deriveTensorArrayProtocol_tensorHandleCount(*this);
670+
if (requirement->getBaseName() == TC.Context.Id_typeList)
671+
return deriveTensorArrayProtocol_typeList(*this);
587672
if (requirement->getBaseName() == DeclBaseName::createConstructor())
588673
return deriveTensorArrayProtocol_init(*this);
589674
TC.diagnose(requirement->getLoc(),

lib/Sema/DerivedConformances.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
230230
// TensorArrayProtocol._tensorHandleCount
231231
if (name.isSimpleName(ctx.Id_tensorHandleCount))
232232
return getRequirement(KnownProtocolKind::TensorArrayProtocol);
233+
234+
// SWIFT_ENABLE_TENSORFLOW
235+
// TensorArrayProtocol._typeList
236+
if (name.isSimpleName(ctx.Id_typeList) && !requirement->isStatic())
237+
return getRequirement(KnownProtocolKind::TensorArrayProtocol);
233238

234239
// SWIFT_ENABLE_TENSORFLOW
235240
// TensorGroup._typeList

stdlib/public/TensorFlow/TensorGroup.swift

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ public protocol TensorArrayProtocol {
3636
func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?)
3737

3838
var _tensorHandleCount: Int32 { get }
39+
var _typeList: [TensorDataType] { get }
3940

4041
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int)
4142
}
@@ -69,13 +70,16 @@ public protocol TensorGroup : TensorArrayProtocol {
6970
public extension TensorGroup {
7071
/// The number of tensor fields in this type.
7172
static var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) }
72-
var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) }
7373

7474
/// An array of `nil`s with the same number of elements as `_outputTypeList`.
7575
/// The `nil` represents unknown shape.
7676
static var _unknownShapeList: [TensorShape?] {
7777
return Array(repeating: nil, count: _typeList.count)
7878
}
79+
80+
// The following instance properties are from `TensorArrayProtocol`.
81+
var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) }
82+
var _typeList: [TensorDataType] { return Self._typeList }
7983

8084
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {
8185
precondition(count == Self._typeList.count)
@@ -223,9 +227,13 @@ extension Array : TensorArrayProtocol where Element : TensorGroup {
223227
}
224228

225229
public var _tensorHandleCount: Int32 {
226-
var count: Int32 = 0
227-
for elem in self { count += elem._tensorHandleCount }
228-
return count
230+
return Element._tensorHandleCount * Int32(count)
231+
}
232+
233+
public var _typeList: [TensorDataType] {
234+
return Array<TensorDataType>([[TensorDataType]](
235+
repeating: Element._typeList,
236+
count: Int(Element._tensorHandleCount)).joined())
229237
}
230238

231239
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {

test/TensorFlowRuntime/tracer.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ TracerTests.testAllBackends("Advanced") {
191191
return model._tensorHandleCount + optimizer._tensorHandleCount
192192
}
193193

194+
public var _typeList: [TensorDataType] {
195+
return model._typeList + optimizer._typeList
196+
}
197+
194198
func _makeInstance<C: Collection>(owning inputs: C) -> State
195199
where C.Element == CTensorHandle {
196200
assert(inputs.count == 4)

0 commit comments

Comments
 (0)