Skip to content

Commit 304cd7b

Browse files
authored
Helpers to create zeros tensor. (#10643)
Summary: #8366 Reviewed By: kirklandsign Differential Revision: D74020937
1 parent 70974aa commit 304cd7b

File tree

3 files changed

+138
-0
lines changed

3 files changed

+138
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,69 @@ __attribute__((deprecated("This API is experimental.")))
902902

903903
@end
904904

905+
#pragma mark - Zeros Category
906+
907+
@interface ExecuTorchTensor (Zeros)
908+
909+
/**
910+
* Creates a tensor filled with zeros, with the specified shape, data type, and shape dynamism.
911+
*
912+
* @param shape An NSArray of NSNumber objects representing the desired shape.
913+
* @param dataType An ExecuTorchDataType value specifying the element type.
914+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
915+
* @return A new ExecuTorchTensor instance filled with zeros.
916+
*/
917+
+ (instancetype)zerosTensorWithShape:(NSArray<NSNumber *> *)shape
918+
dataType:(ExecuTorchDataType)dataType
919+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
920+
NS_SWIFT_NAME(zeros(shape:dataType:shapeDynamism:));
921+
922+
/**
923+
* Creates a tensor filled with zeros, with the specified shape and data type.
924+
*
925+
* @param shape An NSArray of NSNumber objects representing the desired shape.
926+
* @param dataType An ExecuTorchDataType value specifying the element type.
927+
* @return A new ExecuTorchTensor instance filled with zeros.
928+
*/
929+
+ (instancetype)zerosTensorWithShape:(NSArray<NSNumber *> *)shape
930+
dataType:(ExecuTorchDataType)dataType
931+
NS_SWIFT_NAME(zeros(shape:dataType:));
932+
933+
/**
934+
* Creates a tensor filled with zeros similar to an existing tensor, with the specified data type and shape dynamism.
935+
*
936+
* @param tensor An existing ExecuTorchTensor instance whose shape and strides are used.
937+
* @param dataType An ExecuTorchDataType value specifying the desired element type.
938+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
939+
* @return A new ExecuTorchTensor instance filled with zeros.
940+
*/
941+
+ (instancetype)zerosTensorLikeTensor:(ExecuTorchTensor *)tensor
942+
dataType:(ExecuTorchDataType)dataType
943+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
944+
NS_SWIFT_NAME(zeros(like:dataType:shapeDynamism:));
945+
946+
/**
947+
* Creates a tensor filled with zeros similar to an existing tensor, with the specified data type.
948+
*
949+
* @param tensor An existing ExecuTorchTensor instance whose shape and strides are used.
950+
* @param dataType An ExecuTorchDataType value specifying the desired element type.
951+
* @return A new ExecuTorchTensor instance filled with zeros.
952+
*/
953+
+ (instancetype)zerosTensorLikeTensor:(ExecuTorchTensor *)tensor
954+
dataType:(ExecuTorchDataType)dataType
955+
NS_SWIFT_NAME(zeros(like:dataType:));
956+
957+
/**
958+
* Creates a tensor filled with zeros similar to an existing tensor.
959+
*
960+
* @param tensor An existing ExecuTorchTensor instance.
961+
* @return A new ExecuTorchTensor instance filled with zeros.
962+
*/
963+
+ (instancetype)zerosTensorLikeTensor:(ExecuTorchTensor *)tensor
964+
NS_SWIFT_NAME(zeros(like:));
965+
966+
@end
967+
905968
#pragma mark - Random Category
906969

907970
@interface ExecuTorchTensor (Random)

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,56 @@ + (instancetype)onesTensorLikeTensor:(ExecuTorchTensor *)tensor {
776776

777777
@end
778778

779+
@implementation ExecuTorchTensor (Zeros)
780+
781+
+ (instancetype)zerosTensorWithShape:(NSArray<NSNumber *> *)shape
782+
dataType:(ExecuTorchDataType)dataType
783+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
784+
return [self fullTensorWithShape:shape
785+
scalar:@(0)
786+
strides:@[]
787+
dataType:dataType
788+
shapeDynamism:shapeDynamism];
789+
}
790+
791+
+ (instancetype)zerosTensorWithShape:(NSArray<NSNumber *> *)shape
792+
dataType:(ExecuTorchDataType)dataType {
793+
return [self fullTensorWithShape:shape
794+
scalar:@(0)
795+
strides:@[]
796+
dataType:dataType
797+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
798+
}
799+
800+
+ (instancetype)zerosTensorLikeTensor:(ExecuTorchTensor *)tensor
801+
dataType:(ExecuTorchDataType)dataType
802+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
803+
return [self fullTensorWithShape:tensor.shape
804+
scalar:@(0)
805+
strides:tensor.strides
806+
dataType:dataType
807+
shapeDynamism:shapeDynamism];
808+
}
809+
810+
+ (instancetype)zerosTensorLikeTensor:(ExecuTorchTensor *)tensor
811+
dataType:(ExecuTorchDataType)dataType {
812+
return [self fullTensorWithShape:tensor.shape
813+
scalar:@(0)
814+
strides:tensor.strides
815+
dataType:dataType
816+
shapeDynamism:tensor.shapeDynamism];
817+
}
818+
819+
+ (instancetype)zerosTensorLikeTensor:(ExecuTorchTensor *)tensor {
820+
return [self fullTensorWithShape:tensor.shape
821+
scalar:@(0)
822+
strides:tensor.strides
823+
dataType:tensor.dataType
824+
shapeDynamism:tensor.shapeDynamism];
825+
}
826+
827+
@end
828+
779829
@implementation ExecuTorchTensor (Random)
780830

781831
+ (instancetype)randomTensorWithShape:(NSArray<NSNumber *> *)shape

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,31 @@ class TensorTest: XCTestCase {
618618
}
619619
}
620620
}
621+
622+
func testZeros() {
623+
let tensor = Tensor.zeros(shape: [2, 3], dataType: .double)
624+
XCTAssertEqual(tensor.shape, [2, 3])
625+
XCTAssertEqual(tensor.count, 6)
626+
tensor.bytes { pointer, count, dataType in
627+
XCTAssertEqual(dataType, .double)
628+
let buffer = UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Double.self), count: count)
629+
for value in buffer {
630+
XCTAssertEqual(value, 0)
631+
}
632+
}
633+
}
634+
635+
func testZerosLike() {
636+
let other = Tensor.full(shape: [3, 2], scalar: 9, dataType: .int)
637+
let tensor = Tensor.zeros(like: other)
638+
XCTAssertEqual(tensor.shape, other.shape)
639+
tensor.bytes { pointer, count, dataType in
640+
let buffer = UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int32.self), count: count)
641+
for value in buffer {
642+
XCTAssertEqual(value, 0)
643+
}
644+
}
645+
}
621646

622647
func testRandom() {
623648
let tensor = Tensor.rand(shape: [3, 3], dataType: .float)

0 commit comments

Comments
 (0)