Skip to content

Commit 444d7f9

Browse files
authored
Helpers to create random normal tensor.
Differential Revision: D74020939 Pull Request resolved: #10648
1 parent 74dbf15 commit 444d7f9

File tree

3 files changed

+159
-0
lines changed

3 files changed

+159
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,4 +980,87 @@ __attribute__((deprecated("This API is experimental.")))
980980

981981
@end
982982

983+
#pragma mark - RandomNormal Category
984+
985+
@interface ExecuTorchTensor (RandomNormal)
986+
987+
/**
988+
* Creates a tensor with random values drawn from a normal distribution,
989+
* with full specification of shape, strides, data type, and shape dynamism.
990+
*
991+
* @param shape An NSArray of NSNumber objects representing the desired shape.
992+
* @param strides An NSArray of NSNumber objects representing the desired strides.
993+
* @param dataType An ExecuTorchDataType value specifying the element type.
994+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
995+
* @return A new ExecuTorchTensor instance filled with values from a normal distribution.
996+
*/
997+
+ (instancetype)randomNormalTensorWithShape:(NSArray<NSNumber *> *)shape
998+
strides:(NSArray<NSNumber *> *)strides
999+
dataType:(ExecuTorchDataType)dataType
1000+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
1001+
NS_SWIFT_NAME(randn(shape:strides:dataType:shapeDynamism:));
1002+
1003+
/**
1004+
* Creates a tensor with random values drawn from a normal distribution,
1005+
* with the specified shape and data type.
1006+
*
1007+
* @param shape An NSArray of NSNumber objects representing the desired shape.
1008+
* @param dataType An ExecuTorchDataType value specifying the element type.
1009+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
1010+
* @return A new ExecuTorchTensor instance filled with values from a normal distribution.
1011+
*/
1012+
+ (instancetype)randomNormalTensorWithShape:(NSArray<NSNumber *> *)shape
1013+
dataType:(ExecuTorchDataType)dataType
1014+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
1015+
NS_SWIFT_NAME(randn(shape:dataType:shapeDynamism:));
1016+
1017+
/**
1018+
* Creates a tensor with random values drawn from a normal distribution,
1019+
* with the specified shape (using dynamic bound shape) and data type.
1020+
*
1021+
* @param shape An NSArray of NSNumber objects representing the desired shape.
1022+
* @param dataType An ExecuTorchDataType value specifying the element type.
1023+
* @return A new ExecuTorchTensor instance filled with values from a normal distribution.
1024+
*/
1025+
+ (instancetype)randomNormalTensorWithShape:(NSArray<NSNumber *> *)shape
1026+
dataType:(ExecuTorchDataType)dataType
1027+
NS_SWIFT_NAME(randn(shape:dataType:));
1028+
1029+
/**
1030+
* Creates a tensor with random normal values similar to an existing tensor,
1031+
* with the specified data type and shape dynamism.
1032+
*
1033+
* @param tensor An existing ExecuTorchTensor instance whose shape and strides are used.
1034+
* @param dataType An ExecuTorchDataType value specifying the desired element type.
1035+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
1036+
* @return A new ExecuTorchTensor instance filled with values from a normal distribution.
1037+
*/
1038+
+ (instancetype)randomNormalTensorLikeTensor:(ExecuTorchTensor *)tensor
1039+
dataType:(ExecuTorchDataType)dataType
1040+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
1041+
NS_SWIFT_NAME(randn(like:dataType:shapeDynamism:));
1042+
1043+
/**
1044+
* Creates a tensor with random normal values similar to an existing tensor,
1045+
* with the specified data type.
1046+
*
1047+
* @param tensor An existing ExecuTorchTensor instance whose shape and strides are used.
1048+
* @param dataType An ExecuTorchDataType value specifying the desired element type.
1049+
* @return A new ExecuTorchTensor instance filled with values from a normal distribution.
1050+
*/
1051+
+ (instancetype)randomNormalTensorLikeTensor:(ExecuTorchTensor *)tensor
1052+
dataType:(ExecuTorchDataType)dataType
1053+
NS_SWIFT_NAME(randn(like:dataType:));
1054+
1055+
/**
1056+
* Creates a tensor with random normal values similar to an existing tensor.
1057+
*
1058+
* @param tensor An existing ExecuTorchTensor instance.
1059+
* @return A new ExecuTorchTensor instance filled with values from a normal distribution.
1060+
*/
1061+
+ (instancetype)randomNormalTensorLikeTensor:(ExecuTorchTensor *)tensor
1062+
NS_SWIFT_NAME(randn(like:));
1063+
1064+
@end
1065+
9831066
NS_ASSUME_NONNULL_END

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,3 +833,61 @@ + (instancetype)randomTensorLikeTensor:(ExecuTorchTensor *)tensor {
833833
}
834834

835835
@end
836+
837+
@implementation ExecuTorchTensor (RandomNormal)
838+
839+
+ (instancetype)randomNormalTensorWithShape:(NSArray<NSNumber *> *)shape
840+
strides:(NSArray<NSNumber *> *)strides
841+
dataType:(ExecuTorchDataType)dataType
842+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
843+
auto tensor = randn_strided(
844+
utils::toVector<SizesType>(shape),
845+
utils::toVector<StridesType>(strides),
846+
static_cast<ScalarType>(dataType),
847+
static_cast<TensorShapeDynamism>(shapeDynamism)
848+
);
849+
return [[self alloc] initWithNativeInstance:&tensor];
850+
}
851+
852+
+ (instancetype)randomNormalTensorWithShape:(NSArray<NSNumber *> *)shape
853+
dataType:(ExecuTorchDataType)dataType
854+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
855+
return [self randomNormalTensorWithShape:shape
856+
strides:@[]
857+
dataType:dataType
858+
shapeDynamism:shapeDynamism];
859+
}
860+
861+
+ (instancetype)randomNormalTensorWithShape:(NSArray<NSNumber *> *)shape
862+
dataType:(ExecuTorchDataType)dataType {
863+
return [self randomNormalTensorWithShape:shape
864+
strides:@[]
865+
dataType:dataType
866+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
867+
}
868+
869+
+ (instancetype)randomNormalTensorLikeTensor:(ExecuTorchTensor *)tensor
870+
dataType:(ExecuTorchDataType)dataType
871+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
872+
return [self randomNormalTensorWithShape:tensor.shape
873+
strides:tensor.strides
874+
dataType:dataType
875+
shapeDynamism:shapeDynamism];
876+
}
877+
878+
+ (instancetype)randomNormalTensorLikeTensor:(ExecuTorchTensor *)tensor
879+
dataType:(ExecuTorchDataType)dataType {
880+
return [self randomNormalTensorWithShape:tensor.shape
881+
strides:tensor.strides
882+
dataType:dataType
883+
shapeDynamism:tensor.shapeDynamism];
884+
}
885+
886+
+ (instancetype)randomNormalTensorLikeTensor:(ExecuTorchTensor *)tensor {
887+
return [self randomNormalTensorWithShape:tensor.shape
888+
strides:tensor.strides
889+
dataType:tensor.dataType
890+
shapeDynamism:tensor.shapeDynamism];
891+
}
892+
893+
@end

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,4 +637,22 @@ class TensorTest: XCTestCase {
637637
XCTAssertEqual(tensor.shape, other.shape)
638638
XCTAssertEqual(tensor.count, other.count)
639639
}
640+
641+
func testRandomNormal() {
642+
let tensor = Tensor.randn(shape: [4], dataType: .double)
643+
XCTAssertEqual(tensor.shape, [4])
644+
XCTAssertEqual(tensor.count, 4)
645+
tensor.bytes { pointer, count, dataType in
646+
XCTAssertEqual(dataType, .double)
647+
let buffer = UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Double.self), count: count)
648+
XCTAssertEqual(buffer.count, 4)
649+
}
650+
}
651+
652+
func testRandomNormalLike() {
653+
let other = Tensor.zeros(shape: [4], dataType: .float)
654+
let tensor = Tensor.randn(like: other)
655+
XCTAssertEqual(tensor.shape, other.shape)
656+
XCTAssertEqual(tensor.count, other.count)
657+
}
640658
}

0 commit comments

Comments
 (0)