Skip to content

Commit a5fcdc2

Browse files
committed
Bug fixes.
1 parent 701e31d commit a5fcdc2

File tree

3 files changed

+8
-9
lines changed

3 files changed

+8
-9
lines changed

lib/Sema/DerivedConformanceTensorArrayProtocol.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -411,9 +411,8 @@ deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) {
411411
auto *tensorArrayProto = C.getProtocol(
412412
KnownProtocolKind::TensorArrayProtocol);
413413
auto initName = DeclName(
414-
C, DeclBaseName::createConstructor(),
415-
{C.getIdentifier("_owning"), C.getIdentifier("count")});
416-
auto *initReq = getProtocolRequirement(tensorArrayProto, initName);
414+
C, DeclBaseName::createConstructor(), {C.getIdentifier("_owning")});
415+
auto *initReq = getProtocolRequirement(tensorGroupProto, initName);
417416
auto *tensorHandleCountReq = getProtocolRequirement(
418417
tensorArrayProto, C.Id_tensorHandleCount);
419418

lib/Sema/DerivedConformances.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
344344
// SWIFT_ENABLE_TENSORFLOW
345345
// TensorArrayProtocol.init(_owning:count)
346346
if (argumentNames[0] == ctx.getIdentifier("_owning") &&
347-
argumentNames[0] == ctx.getIdentifier("count")) {
347+
argumentNames[1] == ctx.getIdentifier("count")) {
348348
return getRequirement(KnownProtocolKind::TensorArrayProtocol);
349349
}
350350
}

test/TensorFlowRuntime/tensor_array_protocol.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,22 @@ struct Simple : TensorGroup {
1717
var w, b: Tensor<Float>
1818
}
1919

20-
struct Mixed : TensorArrayProtocol {
20+
struct Mixed : TensorGroup {
2121
// Mutable.
2222
var string: StringTensor
2323
var float: Tensor<Float>
2424
// Immutable.
2525
let int: Tensor<Int32>
2626
}
2727

28-
struct Nested : TensorArrayProtocol {
28+
struct Nested : TensorGroup {
2929
// Immutable.
3030
let simple: Simple
3131
// Mutable.
3232
var mixed: Mixed
3333
}
3434

35-
struct Generic<T: TensorGroup, U: TensorGroup> : TensorArrayProtocol {
35+
struct Generic<T: TensorGroup, U: TensorGroup> : TensorGroup {
3636
var t: T
3737
var u: U
3838
}
@@ -157,7 +157,7 @@ TensorArrayProtocolTests.test("GenericUnpackTensorHandles") {
157157
TensorArrayProtocolTests.test("NestedGenericTensorHandleCount") {
158158
struct NestedGeneric {
159159
func function() {
160-
struct UltraNested<T: TensorArrayProtocol, V: TensorArrayProtocol> : TensorArrayProtocol {
160+
struct UltraNested<T: TensorGroup, V: TensorGroup> : TensorArrayProtocol {
161161
var a: Generic<T, V>
162162
var b: Generic<V, T>
163163
}
@@ -181,7 +181,7 @@ TensorArrayProtocolTests.test("NestedGenericTensorHandleCount") {
181181
TensorArrayProtocolTests.test("NestedGenericUnpackTensorHandles") {
182182
struct NestedGeneric {
183183
func function() {
184-
struct UltraNested<T: TensorArrayProtocol, V: TensorArrayProtocol> : TensorArrayProtocol {
184+
struct UltraNested<T: TensorGroup, V: TensorGroup> : TensorArrayProtocol {
185185
var a: Generic<T, V>
186186
var b: Generic<V, T>
187187
}

0 commit comments

Comments
 (0)